File size: 1,040 Bytes
e276af2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

class Model:
    """Thin wrapper pairing a tokenizer with a seq2seq model for text generation.

    The tokenizer is loaded from ``VietAI/vit5-base`` and the model weights from
    ``truong-xuan-linh/vit5-reproduce`` at the requested ``revision``.
    """

    # Default truncation limit (in tokens) applied to the input text.
    MAX_SOURCE_LENGTH = 1024
    # Default ``max_length`` passed to ``generate`` for the output sequence.
    MAX_TARGET_LENGTH = 256

    def __init__(self, revision) -> None:
        """Load the tokenizer and the model checkpoint.

        Args:
            revision: Revision identifier forwarded to
                ``AutoModelForSeq2SeqLM.from_pretrained`` to select which
                version of the checkpoint is loaded.
        """
        self.tokenizer = AutoTokenizer.from_pretrained("VietAI/vit5-base")
        self.model = AutoModelForSeq2SeqLM.from_pretrained(
            "truong-xuan-linh/vit5-reproduce", revision=revision
        )

    def preprocess_function(self, text, max_length: int = MAX_SOURCE_LENGTH):
        """Tokenize ``text`` into the tensors consumed by ``inference``.

        Args:
            text: Input passed straight to the tokenizer.
            max_length: Truncation limit handed to the tokenizer
                (defaults to 1024, the previously hard-coded value).

        Returns:
            The tokenizer output, requested as PyTorch tensors
            (``return_tensors="pt"``) with truncation and padding enabled.
        """
        return self.tokenizer(
            text,
            max_length=max_length,
            truncation=True,
            padding=True,
            return_tensors="pt",
        )

    def inference(self, text, max_target_length: int = MAX_TARGET_LENGTH):
        """Generate output text for ``text`` and return the first decoded result.

        Args:
            text: Input passed to ``preprocess_function``.
            max_target_length: ``max_length`` given to ``generate``
                (defaults to 256, the previously hard-coded value).

        Returns:
            The first generated sequence decoded to a string, with special
            tokens skipped and tokenization spaces left untouched.
        """
        inputs = self.preprocess_function(text)
        generated = self.model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_length=max_target_length,
        )

        # Decoding does not need the deprecated ``as_target_tokenizer()``
        # context (it only affects how text is encoded), and only the first
        # sequence is returned, so only that one is decoded.
        return self.tokenizer.decode(
            generated[0],
            clean_up_tokenization_spaces=False,
            skip_special_tokens=True,
        )