Aftrhour_demo / elise /src /data /mit_seq2seq_dataset.py
BerserkerMother's picture
Adds Flan-T5 seq2seq training
7baf5b5
raw
history blame
3.1 kB
"""
Datasets for seq2seq models.

Classes:
    MITRestaurants: tner/mit_restaurant dataset converted to seq2seq format
Functions:
    get_default_transforms: default transforms for the MIT restaurant dataset
"""
import datasets
class MITRestaurants:
    """
    tner/mit_restaurant for seq2seq

    Attributes
    ----------
    ner_tags: ner tags and ids of mit restaurant
    dataset: hf dataset
    transforms: transforms to apply
    """

    # NOTE(review): ids are not grouped per entity ("I-Price" is 15, after
    # "B-Cuisine"); presumably this mirrors the dataset's own label ids --
    # confirm before reordering.
    ner_tags = {
        "O": 0,
        "B-Rating": 1,
        "I-Rating": 2,
        "B-Amenity": 3,
        "I-Amenity": 4,
        "B-Location": 5,
        "I-Location": 6,
        "B-Restaurant_Name": 7,
        "I-Restaurant_Name": 8,
        "B-Price": 9,
        "B-Hours": 10,
        "I-Hours": 11,
        "B-Dish": 12,
        "I-Dish": 13,
        "B-Cuisine": 14,
        "I-Price": 15,
        "I-Cuisine": 16,
    }

    def __init__(self, dataset: "datasets.DatasetDict", transforms=None):
        """
        Constructs mit datasets

        Parameters:
            dataset: huggingface mit dataset
            transforms: dataset transform functions (iterable of callables),
                or None for no transforms
        """
        self.dataset = dataset
        self.transforms = transforms

    def hf_training(self):
        """
        Returns dataset for huggingface training ecosystem

        Every transform is applied in order through the dataset's ``map``;
        with no transforms the stored dataset is returned unchanged.
        """
        dataset_ = self.dataset
        if self.transforms:
            for transform in self.transforms:
                dataset_ = dataset_.map(transform)
        return dataset_

    def set_transforms(self, transforms):
        """
        Set transforms fn

        The given transforms are appended after the ones already registered.

        Parameters:
            transforms: transforms functions (iterable of callables)

        Returns:
            self, so calls can be chained
        """
        added = list(transforms) if transforms else []
        # Always build a fresh list: an in-place ``+=`` would extend the list
        # object the caller handed to ``__init__`` (or to an earlier call),
        # leaking transforms into every other holder of that list.
        self.transforms = [*(self.transforms or ()), *added]
        return self

    @classmethod
    def from_hf(cls, hf_path: str):
        """
        Constructs dataset from huggingface

        Parameters:
            hf_path: path to dataset hf repo
        """
        return cls(datasets.load_dataset(hf_path))
def get_default_transforms():
    """
    Default transforms for the mit restaurant dataset

    Returns:
        a one-element list holding the function that turns a tagged example
        into a seq2seq pair (plain text in, entity listing out)
    """
    id_to_label = {
        tag_id: label for label, tag_id in MITRestaurants.ner_tags.items()
    }

    def decode_tags(tags, words):
        """
        Collect tagged words into phrases grouped by entity type

        The sequence is scanned from the last word to the first, so phrases
        (and entity types) come out in reverse order of appearance.

        Parameters:
            tags: tag ids, one per word
            words: the words of the example

        Returns:
            dict mapping entity type to the list of its phrases
        """
        entities = {}
        pending = []  # words of the span being rebuilt, last word first
        for tag, word in zip(reversed(tags), reversed(words)):
            if tag == 0:
                continue
            pending.append(word)
            label = id_to_label[tag]
            if not label.startswith("B"):
                # inside a span: keep collecting until its "B-" word shows up
                continue
            phrase = " ".join(reversed(pending)).strip()
            # label[2:] drops the "B-" prefix, leaving the entity type
            entities.setdefault(label[2:], []).append(phrase)
            pending = []
        return entities

    def format_to_text(decoded):
        """
        Render decoded entities as text, one "type: phrase, phrase" line each
        """
        return "".join(
            f"{key}: {', '.join(value)}\n" for key, value in decoded.items()
        )

    def generate_seq2seq_data(example):
        """
        Turn one example ("tokens" and "tags") into source text and labels
        """
        tokens = example["tokens"]
        target = format_to_text(decode_tags(example["tags"], tokens))
        return {"tokens": " ".join(tokens), "labels": target}

    return [generate_seq2seq_data]