Spaces:
Runtime error
Runtime error
BerserkerMother
committed on
Commit
•
7baf5b5
1
Parent(s):
6bd6a70
Adds Flan-T5 seq2seq training
Browse files- elise/src/configs/__init__.py +1 -0
- elise/src/configs/train_t5.py +16 -0
- elise/src/data/__init__.py +5 -0
- elise/src/data/mit_seq2seq_dataset.py +122 -0
- elise/src/data/t5_dataset.py +0 -0
- elise/src/train_t5_seq2seq.py +183 -0
elise/src/configs/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .train_t5 import T5TrainingConfig
|
elise/src/configs/train_t5.py
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass
|
2 |
+
|
3 |
+
|
4 |
+
@dataclass
class T5TrainingConfig:
    """Training configuration for Flan-T5 seq2seq fine-tuning."""

    train_batch_size: int = 32  # per-device batch size for training
    eval_batch_size: int = 32  # per-device batch size for evaluation
    epochs: int = 10  # full passes over the training split
    max_length: int = 512  # max token length for tokenized inputs/targets
    learning_rate: float = 3e-4  # AdamW initial learning rate
    num_warmup_steps: int = 200  # linear LR warmup steps
    mixed_precision: str = "bf16"  # accelerate mixed-precision mode
    gradient_accumulation_steps: int = 4  # batches accumulated per optimizer step
    output_dir: str = "FlanT5_MIT_ner"  # where model/tokenizer checkpoints are saved
|
elise/src/data/__init__.py
CHANGED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Contains datasets and their connectors for model training
|
3 |
+
"""
|
4 |
+
|
5 |
+
from .mit_seq2seq_dataset import MITRestaurants, get_default_transforms
|
elise/src/data/mit_seq2seq_dataset.py
ADDED
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
seq2seq models datasets
|
3 |
+
|
4 |
+
Classes:
|
5 |
+
MITRestaurants: tner/mit_restaurant dataset to seq2seq
|
6 |
+
|
7 |
+
Functions:
|
8 |
+
get_default_transforms: default transforms for mit dataset
|
9 |
+
"""
|
10 |
+
import datasets
|
11 |
+
|
12 |
+
|
13 |
+
class MITRestaurants:
    """
    Wraps the tner/mit_restaurant NER dataset for seq2seq training.

    Attributes
    ----------
    ner_tags: mapping of NER tag name -> integer id used by mit_restaurant
    dataset: the underlying huggingface DatasetDict
    transforms: list of transform functions applied in ``hf_training``
    """

    ner_tags = {
        "O": 0,
        "B-Rating": 1,
        "I-Rating": 2,
        "B-Amenity": 3,
        "I-Amenity": 4,
        "B-Location": 5,
        "I-Location": 6,
        "B-Restaurant_Name": 7,
        "I-Restaurant_Name": 8,
        "B-Price": 9,
        "B-Hours": 10,
        "I-Hours": 11,
        "B-Dish": 12,
        "I-Dish": 13,
        "B-Cuisine": 14,
        "I-Price": 15,
        "I-Cuisine": 16,
    }

    def __init__(self, dataset: "datasets.DatasetDict", transforms=None):
        """
        Construct the wrapper.

        Parameters:
            dataset: huggingface mit dataset (a DatasetDict)
            transforms: optional list of per-example transform functions
        """
        self.dataset = dataset
        self.transforms = transforms

    def hf_training(self):
        """
        Return the dataset for the huggingface training ecosystem,
        with each registered transform applied via ``map`` in order.
        """
        dataset_ = self.dataset
        if self.transforms:
            for transform in self.transforms:
                dataset_ = dataset_.map(transform)
        return dataset_

    def set_transforms(self, transforms):
        """
        Append (or initialise) the transform functions.

        Parameters:
            transforms: list of transform functions

        Returns:
            self, so calls can be chained.
        """
        if self.transforms:
            self.transforms += transforms
        else:
            self.transforms = transforms
        return self

    @classmethod
    def from_hf(cls, hf_path: str):
        """
        Construct the dataset by downloading it from the huggingface hub.

        Parameters:
            hf_path: path to the dataset repo, e.g. "tner/mit_restaurant"
        """
        return cls(datasets.load_dataset(hf_path))
|
87 |
+
|
88 |
+
|
89 |
+
def get_default_transforms():
    """Build the default transform list turning MIT NER rows into seq2seq text pairs."""
    id_to_label = {idx: name for name, idx in MITRestaurants.ner_tags.items()}

    def decode_tags(tags, words):
        # Walk the sentence right-to-left, accumulating tokens until the
        # entity-opening B- tag is reached, then record the complete span.
        entities = {}
        span = ""
        for tag, word in zip(reversed(tags), reversed(words)):
            if tag == 0:
                continue
            span = word + " " + span
            if id_to_label[tag].startswith("B"):
                entity_name = id_to_label[tag][2:]
                entities.setdefault(entity_name, []).append(span.strip())
                span = ""
        return entities

    def format_to_text(decoded):
        # One "Entity: a, b" line per entity type.
        return "".join(
            f"{key}: {', '.join(value)}\n" for key, value in decoded.items()
        )

    def generate_seq2seq_data(example):
        decoded = decode_tags(example["tags"], example["tokens"])
        return {
            "tokens": " ".join(example["tokens"]),
            "labels": format_to_text(decoded),
        }

    return [generate_seq2seq_data]
|
elise/src/data/t5_dataset.py
DELETED
File without changes
|
elise/src/train_t5_seq2seq.py
ADDED
@@ -0,0 +1,183 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import evaluate
|
3 |
+
import datasets
|
4 |
+
from torch.utils.data import DataLoader
|
5 |
+
from datasets import load_dataset
|
6 |
+
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
7 |
+
from dataclasses import asdict
|
8 |
+
|
9 |
+
from transformers import DataCollatorForSeq2Seq
|
10 |
+
from accelerate import Accelerator
|
11 |
+
from transformers import get_scheduler
|
12 |
+
import numpy as np
|
13 |
+
import mlflow
|
14 |
+
|
15 |
+
from tqdm.auto import tqdm
|
16 |
+
|
17 |
+
from data import MITRestaurants, get_default_transforms
|
18 |
+
from utils.logger import get_logger
|
19 |
+
from configs import T5TrainingConfig
|
20 |
+
|
21 |
+
log = get_logger("Flan_T5")

# Build the seq2seq dataset: MIT restaurant NER rows converted to text pairs.
transforms = get_default_transforms()
dataset = (
    MITRestaurants.from_hf("tner/mit_restaurant")
    .set_transforms(transforms)
    .hf_training()
)
# Fold the test split into training; evaluation later uses the validation split.
dataset["train"] = datasets.concatenate_datasets([dataset["train"], dataset["test"]])
print(dataset)

tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base")
model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-base")
|
36 |
+
|
37 |
+
|
38 |
+
def tokenize(example, max_length=512):
    """
    Tokenize a batch of examples for seq2seq training.

    Parameters:
        example: batch dict with "tokens" (input text) and "labels" (target text)
        max_length: truncation length for both inputs and targets.
            NOTE(review): T5TrainingConfig.max_length duplicates this value but
            is never wired through (the config is created after this function
            is used) — consider consolidating.

    Returns:
        tokenizer output with input_ids/attention_mask plus tokenized labels.
    """
    tokenized = tokenizer(
        example["tokens"],
        text_target=example["labels"],
        max_length=max_length,
        truncation=True,
    )

    return tokenized
|
47 |
+
|
48 |
+
|
49 |
+
# Tokenize every split; drop the raw text columns so the collator only
# sees model inputs (input_ids, attention_mask, labels).
tokenized_datasets = dataset.map(
    tokenize,
    batched=True,
    remove_columns=dataset["train"].column_names,
)

# bleu metric
metric = evaluate.load("sacrebleu")
|
57 |
+
|
58 |
+
|
59 |
+
def postprocess(predictions, labels):
    """
    Convert generated token ids and gold label ids into text for sacrebleu.

    Returns (decoded_preds, decoded_labels) where each label is wrapped in a
    single-reference list, as the metric expects.
    """
    pred_ids = predictions.cpu().numpy()
    label_ids = labels.cpu().numpy()

    # -100 marks ignored positions in the labels and cannot be decoded;
    # swap in the pad token id before decoding.
    label_ids = np.where(label_ids != -100, label_ids, tokenizer.pad_token_id)

    pred_texts = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_texts = tokenizer.batch_decode(label_ids, skip_special_tokens=True)

    stripped_preds = [text.strip() for text in pred_texts]
    wrapped_labels = [[text.strip()] for text in label_texts]
    return stripped_preds, wrapped_labels
|
73 |
+
|
74 |
+
|
75 |
+
config = T5TrainingConfig()

# data collator: pads inputs and labels dynamically per batch
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

# data loaders
tokenized_datasets.set_format("torch")
train_dataloader = DataLoader(
    tokenized_datasets["train"],
    shuffle=True,
    collate_fn=data_collator,
    batch_size=config.train_batch_size,
)
eval_dataloader = DataLoader(
    tokenized_datasets["validation"],
    collate_fn=data_collator,
    batch_size=config.eval_batch_size,
)

# optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)
num_update_steps_per_epoch = len(train_dataloader)
# NOTE(review): this counts dataloader batches, not optimizer updates — with
# gradient accumulation the linear schedule will decay faster than the actual
# optimizer steps; confirm this is intended.
num_training_steps = config.epochs * num_update_steps_per_epoch

lr_scheduler = get_scheduler(
    "linear",
    optimizer=optimizer,
    num_warmup_steps=config.num_warmup_steps,
    num_training_steps=num_training_steps,
)

# accelerator: handles device placement, mixed precision and accumulation
accelerator = Accelerator(
    mixed_precision=config.mixed_precision,
    gradient_accumulation_steps=config.gradient_accumulation_steps,
)
model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(
    model, optimizer, train_dataloader, eval_dataloader
)

# one tick per training batch across all epochs
progress_bar = tqdm(range(num_training_steps))
|
116 |
+
|
117 |
+
|
118 |
+
def train(model, dataset, metric):
    """
    Run the full train/eval loop, then save and register the model.

    Parameters:
        model: the (accelerator-prepared) seq2seq model
        dataset: tokenized DatasetDict. NOTE(review): currently unused — the
            module-level dataloaders are used instead; kept for interface
            compatibility, consider building the loaders from this argument.
        metric: an ``evaluate`` metric accumulating (prediction, reference) pairs

    Relies on module-level train_dataloader / eval_dataloader / optimizer /
    lr_scheduler / accelerator / progress_bar / config / tokenizer / log.
    """
    print("Starting Training")
    for epoch in range(config.epochs):
        # Training
        model.train()
        for batch in train_dataloader:
            with accelerator.accumulate(model):
                outputs = model(**batch)
                loss = outputs.loss
                accelerator.backward(loss)

                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()
                progress_bar.update(1)

        # Evaluation
        model.eval()
        for batch in tqdm(eval_dataloader):
            with torch.no_grad():
                generated_tokens = accelerator.unwrap_model(model).generate(
                    batch["input_ids"],
                    attention_mask=batch["attention_mask"],
                    max_length=128,
                )
            labels = batch["labels"]

            # Necessary to pad predictions and labels for being gathered
            generated_tokens = accelerator.pad_across_processes(
                generated_tokens, dim=1, pad_index=tokenizer.pad_token_id
            )
            labels = accelerator.pad_across_processes(labels, dim=1, pad_index=-100)

            predictions_gathered = accelerator.gather(generated_tokens)
            labels_gathered = accelerator.gather(labels)

            decoded_preds, decoded_labels = postprocess(
                predictions_gathered, labels_gathered
            )
            metric.add_batch(predictions=decoded_preds, references=decoded_labels)

        results = metric.compute()
        mlflow.log_metrics({"epoch": epoch, "BLEU score": results["score"]})
        print(f"epoch {epoch}, BLEU score: {results['score']:.2f}")

        # Save and upload a checkpoint at the end of every epoch.
        accelerator.wait_for_everyone()
        unwrapped_model = accelerator.unwrap_model(model)
        unwrapped_model.save_pretrained(
            config.output_dir, save_function=accelerator.save
        )
        if accelerator.is_main_process:
            tokenizer.save_pretrained(config.output_dir)
            # save model with mlflow
            mlflow.transformers.log_model(
                transformers_model={"model": unwrapped_model, "tokenizer": tokenizer},
                task="text2text-generation",
                artifact_path="seq2seq_model",
                registered_model_name="FlanT5_MIT",
            )
|
179 |
+
|
180 |
+
# Point mlflow at the local tracking server before logging anything.
mlflow.set_tracking_uri("http://127.0.0.1:5000")
with mlflow.start_run() as mlflow_run:
    # Record the full training config as run parameters, then train.
    mlflow.log_params(asdict(config))
    train(model, tokenized_datasets, metric)
|