# Fine-tune google/flan-t5-base on the MIT Restaurant dataset (seq2seq NER),
# tracked with MLflow and trained via 🤗 Accelerate.
# --- Imports -----------------------------------------------------------------
# Stdlib
from dataclasses import asdict

# Third-party
import datasets
import evaluate
import mlflow
import numpy as np
import torch
from accelerate import Accelerator
from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    get_scheduler,
)

# Local
from configs import T5TrainingConfig
from data import MITRestaurants, get_default_transforms
from utils.logger import get_logger

# Module-level logger for this training script.
# (A leftover "heloooooooooooo?" debug line was removed here.)
log = get_logger("Flan_T5")
# --- Data & model ------------------------------------------------------------
MODEL_CHECKPOINT = "google/flan-t5-base"

# Build the MIT Restaurant dataset in its seq2seq (tokens -> labels) form.
transforms = get_default_transforms()
dataset = (
    MITRestaurants.from_hf("tner/mit_restaurant")
    .set_transforms(transforms)
    .hf_training()
)
# Fold the test split into training; evaluation below uses the validation split.
dataset["train"] = datasets.concatenate_datasets([dataset["train"], dataset["test"]])
# Use the module logger instead of a bare print for consistency.
log.info(dataset)

tokenizer = AutoTokenizer.from_pretrained(MODEL_CHECKPOINT)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_CHECKPOINT)
def tokenize(example):
    """Tokenize a batch of examples, encoding ``labels`` as the seq2seq target.

    Inputs longer than 512 tokens are truncated. Relies on the module-level
    ``tokenizer``.
    """
    return tokenizer(
        example["tokens"],
        text_target=example["labels"],
        max_length=512,
        truncation=True,
    )
# Tokenize every split; drop the raw text columns so only model inputs remain.
tokenized_datasets = dataset.map(
    tokenize,
    batched=True,
    remove_columns=dataset["train"].column_names,
)

# Corpus-level BLEU (sacreBLEU) for scoring generated label sequences.
metric = evaluate.load("sacrebleu")
def postprocess(predictions, labels):
    """Decode prediction and reference token ids into text for sacreBLEU.

    Returns ``(decoded_preds, decoded_labels)`` where each reference is
    wrapped in a single-element list, the shape sacreBLEU expects.
    """
    pred_ids = predictions.cpu().numpy()
    label_ids = labels.cpu().numpy()
    # -100 marks ignored label positions and cannot be decoded; swap in pad id.
    label_ids = np.where(label_ids == -100, tokenizer.pad_token_id, label_ids)

    decoded_preds = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(label_ids, skip_special_tokens=True)

    return (
        [pred.strip() for pred in decoded_preds],
        [[label.strip()] for label in decoded_labels],
    )
# --- Training setup ----------------------------------------------------------
config = T5TrainingConfig()

# Dynamic padding of inputs/labels per batch.
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

tokenized_datasets.set_format("torch")
train_dataloader = DataLoader(
    tokenized_datasets["train"],
    shuffle=True,
    collate_fn=data_collator,
    batch_size=config.train_batch_size,
)
eval_dataloader = DataLoader(
    tokenized_datasets["validation"],
    collate_fn=data_collator,
    batch_size=config.eval_batch_size,
)

optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)

# The scheduler is stepped once per batch, so total steps = epochs * batches/epoch.
num_update_steps_per_epoch = len(train_dataloader)
num_training_steps = config.epochs * num_update_steps_per_epoch
lr_scheduler = get_scheduler(
    "linear",
    optimizer=optimizer,
    num_warmup_steps=config.num_warmup_steps,
    num_training_steps=num_training_steps,
)

# Accelerator handles device placement, mixed precision and grad accumulation.
accelerator = Accelerator(
    mixed_precision=config.mixed_precision,
    gradient_accumulation_steps=config.gradient_accumulation_steps,
)
model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(
    model, optimizer, train_dataloader, eval_dataloader
)

progress_bar = tqdm(range(num_training_steps))
def train(model, dataset, metric):
    """Fine-tune ``model`` and evaluate with BLEU after every epoch.

    Reads the module-level ``train_dataloader`` / ``eval_dataloader`` /
    ``optimizer`` / ``lr_scheduler`` / ``accelerator`` / ``config``
    (``dataset`` is kept for signature compatibility but is not read
    directly). After each epoch the model is checkpointed to
    ``config.output_dir`` and logged to MLflow from the main process only.
    """
    log.info("Starting Training")
    for epoch in range(config.epochs):
        # ---- Training ----
        model.train()
        for batch in train_dataloader:
            with accelerator.accumulate(model):
                outputs = model(**batch)
                accelerator.backward(outputs.loss)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()
            progress_bar.update(1)

        # ---- Evaluation ----
        model.eval()
        for batch in tqdm(eval_dataloader):
            with torch.no_grad():
                generated_tokens = accelerator.unwrap_model(model).generate(
                    batch["input_ids"],
                    attention_mask=batch["attention_mask"],
                    max_length=128,
                )
            labels = batch["labels"]
            # Pad predictions and labels so they can be gathered across processes.
            generated_tokens = accelerator.pad_across_processes(
                generated_tokens, dim=1, pad_index=tokenizer.pad_token_id
            )
            labels = accelerator.pad_across_processes(labels, dim=1, pad_index=-100)

            predictions_gathered = accelerator.gather(generated_tokens)
            labels_gathered = accelerator.gather(labels)

            decoded_preds, decoded_labels = postprocess(
                predictions_gathered, labels_gathered
            )
            metric.add_batch(predictions=decoded_preds, references=decoded_labels)

        results = metric.compute()
        # Log BLEU against the epoch index; an epoch number is a step, not a metric.
        mlflow.log_metrics({"BLEU score": results["score"]}, step=epoch)
        log.info("epoch %d, BLEU score: %.2f", epoch, results["score"])

        # ---- Checkpoint ----
        accelerator.wait_for_everyone()
        unwrapped_model = accelerator.unwrap_model(model)
        unwrapped_model.save_pretrained(
            config.output_dir, save_function=accelerator.save
        )
        if accelerator.is_main_process:
            tokenizer.save_pretrained(config.output_dir)
            # Log/register from the main process only. The original called
            # log_model unconditionally, which would register a duplicate
            # model version per process in distributed training.
            mlflow.transformers.log_model(
                transformers_model={"model": unwrapped_model, "tokenizer": tokenizer},
                task="text2text-generation",
                artifact_path="seq2seq_model",
                registered_model_name="FlanT5_MIT",
            )
if __name__ == "__main__":
    # Point MLflow at the local tracking server before opening the run.
    mlflow.set_tracking_uri("http://127.0.0.1:5000")
    with mlflow.start_run():
        # Record the full training configuration with the run.
        mlflow.log_params(asdict(config))
        train(model, tokenized_datasets, metric)