BerserkerMother committed on
Commit 7baf5b5 · 1 Parent(s): 6bd6a70

Adds Flan-T5 seq2seq training

elise/src/configs/__init__.py ADDED
@@ -0,0 +1 @@
+ from .train_t5 import T5TrainingConfig
elise/src/configs/train_t5.py ADDED
@@ -0,0 +1,16 @@
+ from dataclasses import dataclass
+
+
+ @dataclass
+ class T5TrainingConfig:
+     """Training configs for T5 fine-tuning"""
+
+     train_batch_size: int = 32
+     eval_batch_size: int = 32
+     epochs: int = 10
+     max_length: int = 512
+     learning_rate: float = 3e-4
+     num_warmup_steps: int = 200
+     mixed_precision: str = "bf16"
+     gradient_accumulation_steps: int = 4
+     output_dir: str = "FlanT5_MIT_ner"
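
For reference, a minimal sketch (not part of this commit) of how T5TrainingConfig is meant to be consumed; the override value below is illustrative only:

    from dataclasses import asdict

    from configs import T5TrainingConfig

    # Defaults come from the dataclass; override a field per run if needed.
    config = T5TrainingConfig(train_batch_size=16)

    # asdict() flattens the config into a plain dict, e.g. for mlflow.log_params().
    print(asdict(config)["learning_rate"])  # 0.0003
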
elise/src/data/__init__.py CHANGED
@@ -0,0 +1,5 @@
+ """
+ Contains datasets and their connectors for model training
+ """
+
+ from .mit_seq2seq_dataset import MITRestaurants, get_default_transforms
elise/src/data/mit_seq2seq_dataset.py ADDED
@@ -0,0 +1,122 @@
+ """
+ Datasets for seq2seq models.
+
+ Classes:
+     MITRestaurants: tner/mit_restaurant dataset to seq2seq
+
+ Functions:
+     get_default_transforms: default transforms for the MIT dataset
+ """
+ import datasets
+
+
+ class MITRestaurants:
+     """
+     tner/mit_restaurant for seq2seq
+
+     Attributes
+     ----------
+     ner_tags: NER tags and ids of mit_restaurant
+     dataset: HF dataset
+     transforms: transforms to apply
+     """
+
+     ner_tags = {
+         "O": 0,
+         "B-Rating": 1,
+         "I-Rating": 2,
+         "B-Amenity": 3,
+         "I-Amenity": 4,
+         "B-Location": 5,
+         "I-Location": 6,
+         "B-Restaurant_Name": 7,
+         "I-Restaurant_Name": 8,
+         "B-Price": 9,
+         "B-Hours": 10,
+         "I-Hours": 11,
+         "B-Dish": 12,
+         "I-Dish": 13,
+         "B-Cuisine": 14,
+         "I-Price": 15,
+         "I-Cuisine": 16,
+     }
+
+     def __init__(self, dataset: datasets.DatasetDict, transforms=None):
+         """
+         Constructs the MIT dataset wrapper.
+
+         Parameters:
+             dataset: Hugging Face MIT dataset
+             transforms: dataset transform functions
+         """
+         self.dataset = dataset
+         self.transforms = transforms
+
+     def hf_training(self):
+         """
+         Returns the dataset for the Hugging Face training ecosystem.
+         """
+         dataset_ = self.dataset
+         if self.transforms:
+             for transform in self.transforms:
+                 dataset_ = dataset_.map(transform)
+         return dataset_
+
+     def set_transforms(self, transforms):
+         """
+         Sets the transform functions.
+
+         Parameters:
+             transforms: transform functions
+         """
+         if self.transforms:
+             self.transforms += transforms
+         else:
+             self.transforms = transforms
+         return self
+
+     @classmethod
+     def from_hf(cls, hf_path: str):
+         """
+         Constructs the dataset from the Hugging Face Hub.
+
+         Parameters:
+             hf_path: path to the dataset's HF repo
+         """
+         return cls(datasets.load_dataset(hf_path))
+
+
+ def get_default_transforms():
+     label_names = {v: k for k, v in MITRestaurants.ner_tags.items()}
+
+     def decode_tags(tags, words):
+         # Walk tokens right-to-left, grouping each entity's words under its B- tag name.
+         dict_out = {}
+         word_ = ""
+         for tag, word in zip(tags[::-1], words[::-1]):
+             if tag == 0:
+                 continue
+             word_ = word + " " + word_
+             if label_names[tag].startswith("B"):
+                 tag_name = label_names[tag][2:]
+                 word_ = word_.strip()
+                 if tag_name not in dict_out:
+                     dict_out[tag_name] = [word_]
+                 else:
+                     dict_out[tag_name].append(word_)
+                 word_ = ""
+         return dict_out
+
+     def format_to_text(decoded):
+         text = ""
+         for key, value in decoded.items():
+             text += f"{key}: {', '.join(value)}\n"
+         return text
+
+     def generate_seq2seq_data(example):
+         decoded = decode_tags(example["tags"], example["tokens"])
+         return {
+             "tokens": " ".join(example["tokens"]),
+             "labels": format_to_text(decoded),
+         }
+
+     return [generate_seq2seq_data]
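
A minimal sketch (not part of this commit) of what the default transform produces; the tokens and tag ids below are made up for illustration and follow the ner_tags mapping above:

    from data import get_default_transforms

    transform = get_default_transforms()[0]

    # Hypothetical example in the tner/mit_restaurant column format:
    # "cheap" tagged B-Price (9) and "sushi" tagged B-Dish (12).
    example = {
        "tokens": ["a", "cheap", "sushi", "place"],
        "tags": [0, 9, 12, 0],
    }

    print(transform(example))
    # {'tokens': 'a cheap sushi place', 'labels': 'Dish: sushi\nPrice: cheap\n'}
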
elise/src/data/t5_dataset.py DELETED
File without changes
elise/src/train_t5_seq2seq.py ADDED
@@ -0,0 +1,183 @@
+ import torch
+ import evaluate
+ import datasets
+ from torch.utils.data import DataLoader
+ from datasets import load_dataset
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
+ from dataclasses import asdict
+
+ from transformers import DataCollatorForSeq2Seq
+ from accelerate import Accelerator
+ from transformers import get_scheduler
+ import numpy as np
+ import mlflow
+
+ from tqdm.auto import tqdm
+
+ from data import MITRestaurants, get_default_transforms
+ from utils.logger import get_logger
+ from configs import T5TrainingConfig
+
+ log = get_logger("Flan_T5")
+
+ # get dataset
+ transforms = get_default_transforms()
+ dataset = (
+     MITRestaurants.from_hf("tner/mit_restaurant")
+     .set_transforms(transforms)
+     .hf_training()
+ )
+ # fold the test split into training; the validation split is kept for evaluation
+ dataset["train"] = datasets.concatenate_datasets([dataset["train"], dataset["test"]])
+ log.info(dataset)
+ tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base")
+ model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-base")
+
+
+ def tokenize(example):
+     tokenized = tokenizer(
+         example["tokens"],
+         text_target=example["labels"],
+         max_length=512,
+         truncation=True,
+     )
+
+     return tokenized
+
+
+ tokenized_datasets = dataset.map(
+     tokenize,
+     batched=True,
+     remove_columns=dataset["train"].column_names,
+ )
+
+ # bleu metric
+ metric = evaluate.load("sacrebleu")
+
+
+ def postprocess(predictions, labels):
+     predictions = predictions.cpu().numpy()
+     labels = labels.cpu().numpy()
+
+     decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
+
+     # Replace -100 in the labels as we can't decode them.
+     labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
+     decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
+
+     # Some simple post-processing
+     decoded_preds = [pred.strip() for pred in decoded_preds]
+     decoded_labels = [[label.strip()] for label in decoded_labels]
+     return decoded_preds, decoded_labels
+
+
+ config = T5TrainingConfig()
+
+ # data collator
+ data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)
+
+ # data loaders
+ tokenized_datasets.set_format("torch")
+ train_dataloader = DataLoader(
+     tokenized_datasets["train"],
+     shuffle=True,
+     collate_fn=data_collator,
+     batch_size=config.train_batch_size,
+ )
+ eval_dataloader = DataLoader(
+     tokenized_datasets["validation"],
+     collate_fn=data_collator,
+     batch_size=config.eval_batch_size,
+ )
+
+ # optimizer
+ optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)
+ num_update_steps_per_epoch = len(train_dataloader)
+ num_training_steps = config.epochs * num_update_steps_per_epoch
+
+ lr_scheduler = get_scheduler(
+     "linear",
+     optimizer=optimizer,
+     num_warmup_steps=config.num_warmup_steps,
+     num_training_steps=num_training_steps,
+ )
+
+ # accelerator
+ accelerator = Accelerator(
+     mixed_precision=config.mixed_precision,
+     gradient_accumulation_steps=config.gradient_accumulation_steps,
+ )
+ model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(
+     model, optimizer, train_dataloader, eval_dataloader
+ )
+
+ progress_bar = tqdm(range(num_training_steps))
+
+
+ def train(model, dataset, metric):
+     log.info("Starting Training")
+     for epoch in range(config.epochs):
+         # Training
+         model.train()
+         for batch in train_dataloader:
+             with accelerator.accumulate(model):
+                 outputs = model(**batch)
+                 loss = outputs.loss
+                 accelerator.backward(loss)
+
+                 optimizer.step()
+                 lr_scheduler.step()
+                 optimizer.zero_grad()
+                 progress_bar.update(1)
+
+         # Evaluation
+         model.eval()
+         for batch in tqdm(eval_dataloader):
+             with torch.no_grad():
+                 generated_tokens = accelerator.unwrap_model(model).generate(
+                     batch["input_ids"],
+                     attention_mask=batch["attention_mask"],
+                     max_length=128,
+                 )
+             labels = batch["labels"]
+
+             # Necessary to pad predictions and labels for being gathered
+             generated_tokens = accelerator.pad_across_processes(
+                 generated_tokens, dim=1, pad_index=tokenizer.pad_token_id
+             )
+             labels = accelerator.pad_across_processes(labels, dim=1, pad_index=-100)
+
+             predictions_gathered = accelerator.gather(generated_tokens)
+             labels_gathered = accelerator.gather(labels)
+
+             decoded_preds, decoded_labels = postprocess(
+                 predictions_gathered, labels_gathered
+             )
+             metric.add_batch(predictions=decoded_preds, references=decoded_labels)
+
+         results = metric.compute()
+         mlflow.log_metrics({"epoch": epoch, "BLEU score": results["score"]})
+         print(f"epoch {epoch}, BLEU score: {results['score']:.2f}")
+
+         # Save and upload
+         accelerator.wait_for_everyone()
+         unwrapped_model = accelerator.unwrap_model(model)
+         unwrapped_model.save_pretrained(
+             config.output_dir, save_function=accelerator.save
+         )
+         if accelerator.is_main_process:
+             tokenizer.save_pretrained(config.output_dir)
+             # save model with mlflow
+             mlflow.transformers.log_model(
+                 transformers_model={"model": unwrapped_model, "tokenizer": tokenizer},
+                 task="text2text-generation",
+                 artifact_path="seq2seq_model",
+                 registered_model_name="FlanT5_MIT",
+             )
+
+
+ mlflow.set_tracking_uri("http://127.0.0.1:5000")
+ with mlflow.start_run() as mlflow_run:
+     mlflow.log_params(asdict(config))
+     train(model, tokenized_datasets, metric)
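
Finally, a hedged sketch (not part of this commit) of reloading the checkpoint that save_pretrained writes to config.output_dir ("FlanT5_MIT_ner" by default); the query and the shown output format are illustrative only:

    from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

    # Assumes training finished and the model was saved to the config's output_dir.
    model_dir = "FlanT5_MIT_ner"
    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_dir)

    text = "any cheap sushi places open late"  # hypothetical query
    inputs = tokenizer(text, return_tensors="pt")
    outputs = model.generate(**inputs, max_length=128)
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))
    # Expected shape of the output (not a guaranteed prediction):
    # "Price: cheap\nDish: sushi\nHours: open late"
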