import datasets
import evaluate
import torch
from accelerate import Accelerator
from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from transformers import (
    AutoModelForTokenClassification,
    AutoTokenizer,
    DataCollatorForTokenClassification,
    get_scheduler,
)
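# BIO label ids used by the tner/mit_restaurant dataset; note the ids are not
# grouped per entity type (e.g. I-Price is 15, I-Cuisine is 16).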
ner_tags = {
"O": 0,
"B-Rating": 1,
"I-Rating": 2,
"B-Amenity": 3,
"I-Amenity": 4,
"B-Location": 5,
"I-Location": 6,
"B-Restaurant_Name": 7,
"I-Restaurant_Name": 8,
"B-Price": 9,
"B-Hours": 10,
"I-Hours": 11,
"B-Dish": 12,
"I-Dish": 13,
"B-Cuisine": 14,
"I-Price": 15,
"I-Cuisine": 16,
}
label_names = {v: k for k, v in ner_tags.items()}
# Merge the validation and test splits into train so the final model is
# trained on all available data (the "test" split used in the evaluation
# loop below is therefore also seen during training).
dataset = load_dataset("tner/mit_restaurant")
dataset["train"] = datasets.concatenate_datasets([dataset["train"], dataset["validation"]])
dataset["train"] = datasets.concatenate_datasets([dataset["train"], dataset["test"]])
print(dataset)
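# A small sentence-embedding checkpoint is reused as a lightweight encoder
# backbone; its fast tokenizer provides word_ids(), which is needed below to
# map sub-word tokens back to whole words when aligning labels.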
tokenizer = AutoTokenizer.from_pretrained(
'sentence-transformers/all-MiniLM-L6-v2')
def align_labels_with_tokens(labels, word_ids):
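    """Expand word-level NER tags into one label per sub-word token.

    Special tokens (word_id is None) get -100 so the loss ignores them, and
    continuation sub-words of an entity word are switched from B-XXX to I-XXX.
    """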
new_labels = []
current_word = None
for word_id in word_ids:
if word_id != current_word:
# Start of a new word!
current_word = word_id
label = -100 if word_id is None else labels[word_id]
new_labels.append(label)
elif word_id is None:
# Special token
new_labels.append(-100)
else:
# Same word as previous token
label = labels[word_id]
# If the label is B-XXX we change it to I-XXX
label_name = label_names[label]
if label_name.startswith("B"):
label = ner_tags["I" + label_name[1:]]
new_labels.append(label)
return new_labels
def tokenize_and_align_labels(examples):
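    """Tokenize pre-split words and attach sub-word-aligned label sequences."""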
tokenized_inputs = tokenizer(
examples["tokens"], truncation=True, is_split_into_words=True
)
all_labels = examples["tags"]
new_labels = []
for i, labels in enumerate(all_labels):
word_ids = tokenized_inputs.word_ids(i)
new_labels.append(align_labels_with_tokens(labels, word_ids))
tokenized_inputs["labels"] = new_labels
return tokenized_inputs
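# Tokenize every split in one batched pass, dropping the raw columns so only
# model inputs and aligned labels remain.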
tokenized_datasets = dataset.map(
tokenize_and_align_labels,
batched=True,
remove_columns=dataset["train"].column_names,
)
def train():
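    """Fine-tune the token-classification model with a plain PyTorch loop run through Accelerate."""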
metric = evaluate.load("seqeval")
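    # The token-classification collator dynamically pads input_ids and
    # attention_mask, and pads the label sequences with -100 so padded
    # positions are ignored by the loss.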
data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer)
train_dataloader = DataLoader(
tokenized_datasets["train"],
shuffle=True,
collate_fn=data_collator,
batch_size=128,
)
eval_dataloader = DataLoader(
tokenized_datasets["test"],
collate_fn=data_collator,
batch_size=8
)
model = AutoModelForTokenClassification.from_pretrained(
'sentence-transformers/all-MiniLM-L6-v2',
id2label=label_names,
label2id=ner_tags,
)
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
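    # Accelerator handles device placement and, when launched with
    # `accelerate launch`, distributed data parallelism.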
accelerator = Accelerator()
model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(
model, optimizer, train_dataloader, eval_dataloader
)
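    # Linear schedule with no warmup: the learning rate decays from 2e-5 to 0
    # over the full training run.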
num_train_epochs = 50
num_update_steps_per_epoch = len(train_dataloader)
num_training_steps = num_train_epochs * num_update_steps_per_epoch
lr_scheduler = get_scheduler(
"linear",
optimizer=optimizer,
num_warmup_steps=0,
num_training_steps=num_training_steps,
)
def postprocess(predictions, labels):
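        """Strip ignored (-100) positions and map label ids to tag strings.

        Returns (true_predictions, true_labels) as lists of tag-string lists,
        the format expected by seqeval.
        """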
predictions = predictions.detach().cpu().clone().numpy()
labels = labels.detach().cpu().clone().numpy()
# Remove ignored index (special tokens) and convert to labels
true_labels = [[label_names[l] for l in label if l != -100]
for label in labels]
true_predictions = [
[label_names[p] for (p, l) in zip(prediction, label) if l != -100]
for prediction, label in zip(predictions, labels)
]
        return true_predictions, true_labels
progress_bar = tqdm(range(num_training_steps))
for epoch in range(num_train_epochs):
# Training
model.train()
for batch in train_dataloader:
outputs = model(**batch)
loss = outputs.loss
accelerator.backward(loss)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
progress_bar.update(1)
# Evaluation
model.eval()
for batch in eval_dataloader:
with torch.no_grad():
outputs = model(**batch)
predictions = outputs.logits.argmax(dim=-1)
labels = batch["labels"]
            # Pad to a common length across processes so ragged tensors can be
            # gathered for the metric.
predictions = accelerator.pad_across_processes(
predictions, dim=1, pad_index=-100)
labels = accelerator.pad_across_processes(
labels, dim=1, pad_index=-100)
predictions_gathered = accelerator.gather(predictions)
labels_gathered = accelerator.gather(labels)
true_predictions, true_labels = postprocess(
predictions_gathered, labels_gathered)
metric.add_batch(predictions=true_predictions,
references=true_labels)
results = metric.compute()
print(
f"epoch {epoch}:",
{
key: results[f"overall_{key}"]
for key in ["precision", "recall", "f1", "accuracy"]
},
)
output_dir = "restaurant_ner"
        # Save a checkpoint at the end of every epoch (the tokenizer files are
        # written by the main process only).
accelerator.wait_for_everyone()
unwrapped_model = accelerator.unwrap_model(model)
unwrapped_model.save_pretrained(
output_dir, save_function=accelerator.save)
if accelerator.is_main_process:
tokenizer.save_pretrained(output_dir)
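    # Final save once all epochs are done.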
accelerator.wait_for_everyone()
unwrapped_model = accelerator.unwrap_model(model)
unwrapped_model.save_pretrained(output_dir, save_function=accelerator.save)
if __name__ == "__main__":
    train()