Spaces:
Runtime error
Runtime error
""" | |
seq2seq models datasets | |
Classes: | |
MITRestaurants: tner/mit_restaurant dataset to seq2seq | |
Functions: | |
get_default_transforms: default transforms for mit dataset | |
""" | |
import datasets | |
class MITRestaurants:
    """
    tner/mit_restaurant dataset adapted for seq2seq training.

    Attributes
    ----------
    ner_tags: BIO tag names of the mit restaurant dataset mapped to their ids
    dataset: huggingface dataset
    transforms: list of per-example functions applied with ``dataset.map``
        by ``hf_training``, or None when no transforms are configured
    """

    ner_tags = {
        "O": 0,
        "B-Rating": 1,
        "I-Rating": 2,
        "B-Amenity": 3,
        "I-Amenity": 4,
        "B-Location": 5,
        "I-Location": 6,
        "B-Restaurant_Name": 7,
        "I-Restaurant_Name": 8,
        "B-Price": 9,
        "B-Hours": 10,
        "I-Hours": 11,
        "B-Dish": 12,
        "I-Dish": 13,
        "B-Cuisine": 14,
        "I-Price": 15,
        "I-Cuisine": 16,
    }

    # The annotation is a string so it is not evaluated at definition time.
    def __init__(self, dataset: "datasets.DatasetDict", transforms=None):
        """
        Construct the mit dataset wrapper.

        Parameters:
            dataset: huggingface mit dataset
            transforms: iterable of dataset transform functions, or None
        """
        self.dataset = dataset
        self.transforms = transforms

    def hf_training(self):
        """
        Return the dataset for the huggingface training ecosystem.

        Each configured transform is applied in order via ``dataset.map``;
        the result of one ``map`` call is the input of the next. With no
        transforms, the stored dataset object itself is returned.
        """
        dataset_ = self.dataset
        for transform in self.transforms or ():
            dataset_ = dataset_.map(transform)
        return dataset_

    def set_transforms(self, transforms):
        """
        Append transform functions after any already configured ones.

        A new list is built instead of extending in place, so a list handed
        to the constructor or to an earlier call is never mutated.

        Parameters:
            transforms: iterable of transform functions; None adds nothing

        Returns:
            self, so calls can be chained
        """
        existing = list(self.transforms) if self.transforms else []
        added = list(transforms) if transforms else []
        self.transforms = existing + added
        return self

    @classmethod
    def from_hf(cls, hf_path: str):
        """
        Construct the dataset wrapper from a huggingface hub dataset.

        Parameters:
            hf_path: path to dataset hf repo
        """
        return cls(datasets.load_dataset(hf_path))
def get_default_transforms(ner_tags=None):
    """
    Build the default transforms that turn BIO-tagged examples into seq2seq pairs.

    The returned transform maps an example with ``tokens`` (list of words)
    and ``tags`` (list of integer tag ids) to a dict with:
        tokens: the words joined by single spaces
        labels: one "<entity type>: <span>, <span>\\n" line per entity type

    Tags are decoded right-to-left. Spans of one type are therefore listed
    from the end of the sentence backwards, and entity types are ordered by
    their rightmost occurrence.

    Parameters:
        ner_tags: mapping from BIO tag name ("O", "B-X", "I-X") to id;
            defaults to MITRestaurants.ner_tags

    Returns:
        list containing the single transform function
    """
    if ner_tags is None:
        ner_tags = MITRestaurants.ner_tags
    label_names = {v: k for k, v in ner_tags.items()}
    outside_id = ner_tags.get("O", 0)

    def decode_tags(tags, words):
        dict_out = {}
        span = ""
        # Entity type of I- words collected in `span` that have not yet been
        # closed by a B- tag.
        pending_type = None

        def add(entity_type, text):
            dict_out.setdefault(entity_type, []).append(text.strip())

        for tag, word in zip(tags[::-1], words[::-1]):
            if tag == outside_id:
                # An O tag ends the current span. Orphan I- words (no B- tag)
                # are kept under their own type instead of being glued onto
                # an unrelated entity further left.
                if span:
                    add(pending_type, span)
                    span = ""
                    pending_type = None
                continue
            label = label_names[tag]
            span = word + " " + span
            if label.startswith("B"):
                # Reading right-to-left, the B- tag closes the entity.
                add(label[2:], span)
                span = ""
                pending_type = None
            else:
                pending_type = label[2:]
        if span:
            # I- words at the start of the sentence with no B- tag.
            add(pending_type, span)
        return dict_out

    def format_to_text(decoded):
        return "".join(
            f"{key}: {', '.join(value)}\n" for key, value in decoded.items()
        )

    def generate_seq2seq_data(example):
        decoded = decode_tags(example["tags"], example["tokens"])
        return {
            "tokens": " ".join(example["tokens"]),
            "labels": format_to_text(decoded),
        }

    return [generate_seq2seq_data]