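# Fine-tune a MiniLM encoder (sentence-transformers/all-MiniLM-L6-v2) for
# named-entity recognition on the MIT Restaurant dataset using a custom
# Hugging Face Accelerate training loop, then save it to ./restaurant_ner.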
import datasets
import evaluate
import torch
from accelerate import Accelerator
from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from transformers import (
    AutoModelForTokenClassification,
    AutoTokenizer,
    DataCollatorForTokenClassification,
    get_scheduler,
)
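# BIO tag -> id mapping, matching the label2id of the tner/mit_restaurant
# dataset (the non-contiguous ordering of the Price/Cuisine ids is deliberate).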
ner_tags = {
    "O": 0,
    "B-Rating": 1,
    "I-Rating": 2,
    "B-Amenity": 3,
    "I-Amenity": 4,
    "B-Location": 5,
    "I-Location": 6,
    "B-Restaurant_Name": 7,
    "I-Restaurant_Name": 8,
    "B-Price": 9,
    "B-Hours": 10,
    "I-Hours": 11,
    "B-Dish": 12,
    "I-Dish": 13,
    "B-Cuisine": 14,
    "I-Price": 15,
    "I-Cuisine": 16,
}
label_names = {v: k for k, v in ner_tags.items()}  # inverse mapping: id -> tag (id2label)
# Dataset aggregation: fold the validation and test splits into the training
# split so the model is trained on all available data. Note that the "test"
# split used for evaluation below is therefore also seen during training.
dataset = load_dataset("tner/mit_restaurant")
dataset["train"] = datasets.concatenate_datasets(
    [dataset["train"], dataset["validation"], dataset["test"]]
)
print(dataset)
tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
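# The tokenizer splits words into subword pieces, so the word-level tags in the
# dataset have to be realigned to token level before training.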
def align_labels_with_tokens(labels, word_ids):
    """Expand word-level labels onto subword tokens.

    Special tokens get -100 (ignored by the loss); continuation pieces of a
    B-XXX word are relabeled I-XXX.
    """
    new_labels = []
    current_word = None
    for word_id in word_ids:
        if word_id != current_word:
            # Start of a new word (or a special token)
            current_word = word_id
            label = -100 if word_id is None else labels[word_id]
            new_labels.append(label)
        elif word_id is None:
            # Special token
            new_labels.append(-100)
        else:
            # Same word as the previous token
            label = labels[word_id]
            # If the label is B-XXX, change it to I-XXX
            label_name = label_names[label]
            if label_name.startswith("B"):
                label = ner_tags["I" + label_name[1:]]
            new_labels.append(label)
    return new_labels
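# For example (hypothetical values): with word-level tags [3, 4] ("B-Amenity",
# "I-Amenity") and word_ids [None, 0, 0, 1, None], the aligned token labels
# come out as [-100, 3, 4, 4, -100].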
def tokenize_and_align_labels(examples):
    tokenized_inputs = tokenizer(
        examples["tokens"], truncation=True, is_split_into_words=True
    )
    all_labels = examples["tags"]
    new_labels = []
    for i, labels in enumerate(all_labels):
        word_ids = tokenized_inputs.word_ids(i)
        new_labels.append(align_labels_with_tokens(labels, word_ids))
    tokenized_inputs["labels"] = new_labels
    return tokenized_inputs
tokenized_datasets = dataset.map(
    tokenize_and_align_labels,
    batched=True,
    remove_columns=dataset["train"].column_names,
)
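# After mapping, each split holds only the tokenizer outputs plus the aligned
# "labels" column; the original "tokens" and "tags" columns are removed.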
def train():
    metric = evaluate.load("seqeval")
    # Pads inputs and labels dynamically per batch (labels are padded with -100)
    data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer)
    train_dataloader = DataLoader(
        tokenized_datasets["train"],
        shuffle=True,
        collate_fn=data_collator,
        batch_size=128,
    )
    eval_dataloader = DataLoader(
        tokenized_datasets["test"],
        collate_fn=data_collator,
        batch_size=8,
    )
    # The sentence-transformers checkpoint has no token-classification head, so
    # a new one is initialized (from_pretrained will warn about this).
    model = AutoModelForTokenClassification.from_pretrained(
        "sentence-transformers/all-MiniLM-L6-v2",
        id2label=label_names,
        label2id=ner_tags,
    )
    optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
    accelerator = Accelerator()
    model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(
        model, optimizer, train_dataloader, eval_dataloader
    )
    num_train_epochs = 50
    num_update_steps_per_epoch = len(train_dataloader)
    num_training_steps = num_train_epochs * num_update_steps_per_epoch
    # Linear decay from 2e-5 to 0 over all training steps, with no warmup
    lr_scheduler = get_scheduler(
        "linear",
        optimizer=optimizer,
        num_warmup_steps=0,
        num_training_steps=num_training_steps,
    )
    def postprocess(predictions, labels):
        predictions = predictions.detach().cpu().clone().numpy()
        labels = labels.detach().cpu().clone().numpy()
        # Remove the ignored index (special tokens) and convert ids to tag names
        true_labels = [
            [label_names[l] for l in label if l != -100] for label in labels
        ]
        true_predictions = [
            [label_names[p] for (p, l) in zip(prediction, label) if l != -100]
            for prediction, label in zip(predictions, labels)
        ]
        return true_labels, true_predictions
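    # Example (hypothetical values): for predictions [[0, 1, 2, 0]] and labels
    # [[-100, 1, 2, -100]], only positions whose label is not -100 survive,
    # giving true_labels [["B-Rating", "I-Rating"]] and true_predictions
    # [["B-Rating", "I-Rating"]].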
    progress_bar = tqdm(range(num_training_steps))
    for epoch in range(num_train_epochs):
        # Training
        model.train()
        for batch in train_dataloader:
            outputs = model(**batch)
            loss = outputs.loss
            accelerator.backward(loss)
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
            progress_bar.update(1)
        # Evaluation
        model.eval()
        for batch in eval_dataloader:
            with torch.no_grad():
                outputs = model(**batch)
            predictions = outputs.logits.argmax(dim=-1)
            labels = batch["labels"]
            # Pad predictions and labels so they can be gathered across processes
            predictions = accelerator.pad_across_processes(
                predictions, dim=1, pad_index=-100
            )
            labels = accelerator.pad_across_processes(labels, dim=1, pad_index=-100)
            predictions_gathered = accelerator.gather(predictions)
            labels_gathered = accelerator.gather(labels)
            # postprocess returns (true_labels, true_predictions), in that order
            true_labels, true_predictions = postprocess(
                predictions_gathered, labels_gathered
            )
            metric.add_batch(predictions=true_predictions, references=true_labels)
        results = metric.compute()
        print(
            f"epoch {epoch}:",
            {
                key: results[f"overall_{key}"]
                for key in ["precision", "recall", "f1", "accuracy"]
            },
        )
    # Save the fine-tuned model and tokenizer once training is done
    output_dir = "restaurant_ner"
    accelerator.wait_for_everyone()
    unwrapped_model = accelerator.unwrap_model(model)
    unwrapped_model.save_pretrained(output_dir, save_function=accelerator.save)
    if accelerator.is_main_process:
        tokenizer.save_pretrained(output_dir)
if __name__ == "__main__":
    train()
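# A minimal sketch of running inference with the saved model (assumes training
# has finished and the "restaurant_ner" directory exists; the example sentence
# is made up):
#
#   from transformers import pipeline
#   ner = pipeline(
#       "token-classification",
#       model="restaurant_ner",
#       aggregation_strategy="simple",
#   )
#   print(ner("any cheap sushi places open after midnight near downtown"))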