"""Seq2seq datasets built from token-classification (NER) corpora.

Classes:
    MITRestaurants: wraps the tner MIT restaurant NER dataset so it can be
        converted to seq2seq (text -> text) examples.

Functions:
    get_default_transforms: default transforms for the MIT restaurant dataset.
"""
from typing import Callable, Iterable, Optional, Union

import datasets


class MITRestaurants:
    """tner MIT restaurant NER dataset adapted for seq2seq training.

    Attributes
    ----------
    ner_tags : dict[str, int]
        IOB tag name -> integer tag id for the MIT restaurant dataset.
    dataset : datasets.DatasetDict
        The wrapped Hugging Face dataset.
    transforms : list of callables or None
        Per-example functions applied (in order) by ``hf_training``.
    """

    ner_tags = {
        "O": 0,
        "B-Rating": 1,
        "I-Rating": 2,
        "B-Amenity": 3,
        "I-Amenity": 4,
        "B-Location": 5,
        "I-Location": 6,
        "B-Restaurant_Name": 7,
        "I-Restaurant_Name": 8,
        "B-Price": 9,
        "B-Hours": 10,
        "I-Hours": 11,
        "B-Dish": 12,
        "I-Dish": 13,
        "B-Cuisine": 14,
        "I-Price": 15,
        "I-Cuisine": 16,
    }

    def __init__(
        self,
        dataset: datasets.DatasetDict,
        transforms: Optional[Iterable[Callable]] = None,
    ):
        """Wrap a Hugging Face MIT restaurant dataset.

        Parameters
        ----------
        dataset : datasets.DatasetDict
            Hugging Face MIT restaurant dataset.
        transforms : iterable of callables, optional
            Functions applied to every example by ``hf_training``.
        """
        self.dataset = dataset
        self.transforms = transforms

    def hf_training(self):
        """Return the dataset with every transform applied, in order.

        Each transform is passed to ``Dataset.map`` as a per-example function.
        ``self.dataset`` itself is not reassigned.
        """
        dataset_ = self.dataset
        for transform in self.transforms or ():
            dataset_ = dataset_.map(transform)
        return dataset_

    def set_transforms(
        self, transforms: Union[Callable, Iterable[Callable], None]
    ):
        """Append transforms to the ones already configured.

        Parameters
        ----------
        transforms : callable or iterable of callables
            Transform function(s) to append. A single callable is treated as
            a one-element list; ``None`` appends nothing.

        Returns
        -------
        MITRestaurants
            ``self``, so calls can be chained.
        """
        if transforms is None:
            new_transforms = []
        elif callable(transforms):
            new_transforms = [transforms]
        else:
            new_transforms = list(transforms)
        # Build a fresh list instead of using ``+=``: the existing list may be
        # one the caller passed in, and must not be mutated behind their back.
        self.transforms = list(self.transforms or []) + new_transforms
        return self

    @classmethod
    def from_hf(cls, hf_path: str, **load_kwargs):
        """Construct an instance from ``datasets.load_dataset``.

        Parameters
        ----------
        hf_path : str
            Path of the dataset repo, passed to ``datasets.load_dataset``.
        **load_kwargs
            Extra keyword arguments forwarded to ``datasets.load_dataset``.
        """
        return cls(datasets.load_dataset(hf_path, **load_kwargs))


def get_default_transforms():
    """Return the default seq2seq transforms for the MIT restaurant dataset.

    Returns
    -------
    list of callables
        A single transform mapping an example with ``tokens`` (list of words)
        and ``tags`` (list of integer tag ids from ``MITRestaurants.ner_tags``)
        to ``{"tokens": <space-joined words>, "labels": <entity text>}``.
        The label text has one line per entity type, of the form
        ``<EntityType>: <span>, <span>``, each terminated by a newline.
    """
    label_names = {v: k for k, v in MITRestaurants.ner_tags.items()}

    def decode_tags(tags, words):
        """Group IOB-tagged words into ``{entity_type: [span, ...]}``.

        The sequence is scanned right-to-left: words are buffered until the
        ``B-`` tag that opens their span, at which point the span is emitted.
        As a result, spans of one type are listed last-to-first, and entity
        types are ordered by their right-most occurrence in the sentence.
        """
        entities = {}
        pending = []  # words of the span being built, in right-to-left order
        for tag, word in zip(reversed(tags), reversed(words)):
            if tag == 0:
                # An "O" tag ends any span. Words still buffered here come from
                # an "I-" run with no opening "B-" tag (malformed IOB); drop
                # them so they are not glued onto the next entity to the left.
                pending = []
                continue
            pending.append(word)
            label = label_names[tag]
            if label.startswith("B"):
                span = " ".join(reversed(pending)).strip()
                entities.setdefault(label[2:], []).append(span)
                pending = []
        return entities

    def format_to_text(decoded):
        """Render ``{entity_type: [spans]}`` as one text line per entity type."""
        return "".join(
            f"{key}: {', '.join(value)}\n" for key, value in decoded.items()
        )

    def generate_seq2seq_data(example):
        """Convert one NER example into a (source text, target text) pair."""
        decoded = decode_tags(example["tags"], example["tokens"])
        return {
            "tokens": " ".join(example["tokens"]),
            "labels": format_to_text(decoded),
        }

    return [generate_seq2seq_data]