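"""Fine-tune Flan-T5 on the MIT Restaurants dataset framed as a seq2seq task.

Training runs through Accelerate, evaluation uses sacreBLEU, and parameters,
metrics, and the final model are tracked with MLflow.
"""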
import datasets
import evaluate
import mlflow
import numpy as np
import torch
from accelerate import Accelerator
from dataclasses import asdict
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    get_scheduler,
)

from configs import T5TrainingConfig
from data import MITRestaurants, get_default_transforms
from utils.logger import get_logger
log = get_logger("Flan_T5")
log.debug("heloooooooooooo?")
# Load the MIT Restaurants dataset and apply the default transforms.
transforms = get_default_transforms()
dataset = (
MITRestaurants.from_hf("tner/mit_restaurant")
.set_transforms(transforms)
.hf_training()
)
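# Fold the test split into the training data; the validation split is held out for evaluation.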
dataset["train"] = datasets.concatenate_datasets([dataset["train"], dataset["test"]])
log.info(dataset)
tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base")
model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-base")
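# Tokenize inputs and targets in one pass; `text_target` routes the label
# strings through the tokenizer as decoder targets.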
def tokenize(example):
tokenized = tokenizer(
example["tokens"],
text_target=example["labels"],
max_length=512,
truncation=True,
)
return tokenized
tokenized_datasets = dataset.map(
tokenize,
batched=True,
remove_columns=dataset["train"].column_names,
)
# bleu metric
metric = evaluate.load("sacrebleu")
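# Decode generated ids and gold label ids back to text for sacreBLEU, which
# expects prediction strings and one list of reference strings per example.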
def postprocess(predictions, labels):
predictions = predictions.cpu().numpy()
labels = labels.cpu().numpy()
decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
# Replace -100 in the labels as we can't decode them.
labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
# Some simple post-processing
decoded_preds = [pred.strip() for pred in decoded_preds]
decoded_labels = [[label.strip()] for label in decoded_labels]
return decoded_preds, decoded_labels
config = T5TrainingConfig()
# data collator: dynamically pads inputs, and pads labels with -100 so padding is ignored by the loss
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)
# data loaders
tokenized_datasets.set_format("torch")
train_dataloader = DataLoader(
tokenized_datasets["train"],
shuffle=True,
collate_fn=data_collator,
batch_size=config.train_batch_size,
)
eval_dataloader = DataLoader(
tokenized_datasets["validation"],
collate_fn=data_collator,
batch_size=config.eval_batch_size,
)
# optimizer and linear LR schedule with warmup
optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)
num_update_steps_per_epoch = len(train_dataloader)
num_training_steps = config.epochs * num_update_steps_per_epoch
lr_scheduler = get_scheduler(
"linear",
optimizer=optimizer,
num_warmup_steps=config.num_warmup_steps,
num_training_steps=num_training_steps,
)
# accelerator: handles device placement, mixed precision, and gradient accumulation
accelerator = Accelerator(
mixed_precision=config.mixed_precision,
gradient_accumulation_steps=config.gradient_accumulation_steps,
)
model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(
model, optimizer, train_dataloader, eval_dataloader
)
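# Note: the LR scheduler is stepped manually in the training loop, so it is not passed to prepare().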
progress_bar = tqdm(range(num_training_steps))
def train(model, dataset, metric):
# log.info("Starting Training")
print("Starting Traning")
for epoch in range(config.epochs):
# Training
model.train()
for batch in train_dataloader:
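            # Inside accumulate(), gradient sync and the optimizer step only
            # take effect on accumulation boundaries.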
with accelerator.accumulate(model):
outputs = model(**batch)
loss = outputs.loss
accelerator.backward(loss)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
progress_bar.update(1)
# Evaluation
model.eval()
for batch in tqdm(eval_dataloader):
with torch.no_grad():
generated_tokens = accelerator.unwrap_model(model).generate(
batch["input_ids"],
attention_mask=batch["attention_mask"],
max_length=128,
)
labels = batch["labels"]
# Necessary to pad predictions and labels for being gathered
generated_tokens = accelerator.pad_across_processes(
generated_tokens, dim=1, pad_index=tokenizer.pad_token_id
)
labels = accelerator.pad_across_processes(labels, dim=1, pad_index=-100)
predictions_gathered = accelerator.gather(generated_tokens)
labels_gathered = accelerator.gather(labels)
decoded_preds, decoded_labels = postprocess(
predictions_gathered, labels_gathered
)
metric.add_batch(predictions=decoded_preds, references=decoded_labels)
results = metric.compute()
        mlflow.log_metric("BLEU score", results["score"], step=epoch)
        log.info(f"epoch {epoch}, BLEU score: {results['score']:.2f}")
# Save and upload
accelerator.wait_for_everyone()
unwrapped_model = accelerator.unwrap_model(model)
unwrapped_model.save_pretrained(
config.output_dir, save_function=accelerator.save
)
if accelerator.is_main_process:
tokenizer.save_pretrained(config.output_dir)
# save model with mlflow
mlflow.transformers.log_model(
transformers_model={"model": unwrapped_model, "tokenizer": tokenizer},
task="text2text-generation",
artifact_path="seq2seq_model",
registered_model_name="FlanT5_MIT"
)
mlflow.set_tracking_uri("http://127.0.0.1:5000")
with mlflow.start_run():
mlflow.log_params(asdict(config))
train(model, tokenized_datasets, metric)
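
# A minimal inference sketch (hypothetical usage, assuming the run above
# completed and `config.output_dir` holds the exported checkpoint):
#
#   from transformers import pipeline
#   tagger = pipeline("text2text-generation", model=config.output_dir)
#   print(tagger("show me cheap italian restaurants nearby"))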