from transformers import AutoTokenizer, AutoModelForSeq2SeqLM


class Model:
    """Thin wrapper around a ViT5 seq2seq checkpoint for text-to-text inference.

    Loads a pretrained tokenizer and model once at construction time and
    exposes a single ``inference`` entry point that maps an input string to
    the generated output string.
    """

    def __init__(
        self,
        revision,
        model_name: str = "truong-xuan-linh/vit5-reproduce",
        tokenizer_name: str = "VietAI/vit5-base",
    ) -> None:
        """Load tokenizer and model weights.

        Args:
            revision: Git revision (branch/tag/commit) of the model repo to load.
            model_name: HF Hub repo id of the seq2seq weights (default matches
                the original hard-coded checkpoint).
            tokenizer_name: HF Hub repo id of the tokenizer (default matches
                the original hard-coded checkpoint).
        """
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        # from_pretrained returns the model already in eval mode.
        self.model = AutoModelForSeq2SeqLM.from_pretrained(
            model_name, revision=revision
        )

    def preprocess_function(self, text, max_length: int = 1024):
        """Tokenize *text* into model-ready PyTorch tensors.

        Args:
            text: A string (or list of strings) to tokenize.
            max_length: Truncation limit in tokens (default preserves the
                original hard-coded 1024).

        Returns:
            A ``BatchEncoding`` with ``input_ids`` and ``attention_mask``
            tensors.
        """
        return self.tokenizer(
            text,
            max_length=max_length,
            truncation=True,
            padding=True,
            return_tensors="pt",
        )

    def inference(self, text, max_target_length: int = 256):
        """Generate the model's output string for a single input *text*.

        Args:
            text: Input string to run through the model.
            max_target_length: Maximum number of generated tokens
                (default preserves the original hard-coded 256).

        Returns:
            The decoded generation for the (first) input, with special
            tokens stripped.
        """
        import torch  # local import: torch is a hard dependency of transformers

        inputs = self.preprocess_function(text)
        # no_grad: inference only — avoids building an autograd graph,
        # saving memory and time versus the original code.
        with torch.no_grad():
            outputs = self.model.generate(
                input_ids=inputs["input_ids"],
                max_length=max_target_length,
                attention_mask=inputs["attention_mask"],
            )
        # NOTE(review): the original wrapped decoding in the deprecated
        # as_target_tokenizer() context; that context only affects *encoding*
        # of target text, so plain batch_decode is equivalent here.
        decoded = self.tokenizer.batch_decode(
            outputs,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )
        return decoded[0]