"""
seq2seq models datasets

Classes:
    MITRestaurants: tner/mit_restaurant dataset to seq2seq

Functions: 
    get_default_transforms: default transforms for mit dataset
"""
import datasets


class MITRestaurants:
    """
    Wraps tner/mit_restaurant for seq2seq training.

    Attributes
    ----------
    ner_tags: mapping of NER tag names to ids for the MIT restaurant dataset
    dataset: the underlying huggingface DatasetDict
    transforms: list of transform functions to apply
    """

    ner_tags = {
        "O": 0,
        "B-Rating": 1,
        "I-Rating": 2,
        "B-Amenity": 3,
        "I-Amenity": 4,
        "B-Location": 5,
        "I-Location": 6,
        "B-Restaurant_Name": 7,
        "I-Restaurant_Name": 8,
        "B-Price": 9,
        "B-Hours": 10,
        "I-Hours": 11,
        "B-Dish": 12,
        "I-Dish": 13,
        "B-Cuisine": 14,
        "I-Price": 15,
        "I-Cuisine": 16,
    }

    def __init__(self, dataset: datasets.DatasetDict, transforms=None):
        """
        Constructs the MIT restaurant dataset wrapper.

        Parameters:
            dataset: huggingface MIT restaurant dataset
            transforms: list of dataset transform functions
        """
        self.dataset = dataset
        self.transforms = transforms

    def hf_training(self):
        """
        Returns the dataset for the huggingface training ecosystem,
        with all registered transforms applied.
        """
        dataset_ = self.dataset
        if self.transforms:
            for transform in self.transforms:
                dataset_ = dataset_.map(transform)
        return dataset_

    def set_transforms(self, transforms):
        """
        Registers transform functions, appending to any already set.

        Parameters:
            transforms: list of transform functions
        """
        if self.transforms:
            self.transforms += transforms
        else:
            self.transforms = transforms
        return self

    @classmethod
    def from_hf(cls, hf_path: str):
        """
        Constructs the dataset by loading it from the huggingface hub.

        Parameters:
            hf_path: path to the dataset repo on the huggingface hub
        """
        return cls(datasets.load_dataset(hf_path))


def get_default_transforms():
    """
    Returns the default transforms that turn tagged MIT restaurant examples
    into plain-text (input, label) pairs for seq2seq training.
    """
    label_names = {v: k for k, v in MITRestaurants.ner_tags.items()}

    def decode_tags(tags, words):
        # Walk tokens right-to-left, accumulating each entity's words until
        # its B- tag is reached, then group the entity strings by tag name.
        dict_out = {}
        word_ = ""
        for tag, word in zip(tags[::-1], words[::-1]):
            if tag == 0:
                continue
            word_ = word + " " + word_
            if label_names[tag].startswith("B"):
                tag_name = label_names[tag][2:]
                word_ = word_.strip()
                if tag_name not in dict_out:
                    dict_out[tag_name] = [word_]
                else:
                    dict_out[tag_name].append(word_)
                word_ = ""
        return dict_out

    def format_to_text(decoded):
        # Render the decoded entities as one "Tag: value, value" line per tag.
        text = ""
        for key, value in decoded.items():
            text += f"{key}: {', '.join(value)}\n"
        return text

    def generate_seq2seq_data(example):
        # Map a token/tag example to a flattened input sentence and text label.
        decoded = decode_tags(example["tags"], example["tokens"])
        return {
            "tokens": " ".join(example["tokens"]),
            "labels": format_to_text(decoded),
        }

    return [generate_seq2seq_data]
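

if __name__ == "__main__":
    # Minimal usage sketch, not part of the module API above: build the
    # wrapper from the huggingface hub, register the default transforms, and
    # inspect one transformed training example. Assumes network access to
    # download tner/mit_restaurant and that the DatasetDict has a "train"
    # split.
    mit = MITRestaurants.from_hf("tner/mit_restaurant")
    mit.set_transforms(get_default_transforms())
    hf_dataset = mit.hf_training()
    example = hf_dataset["train"][0]
    print(example["tokens"])  # flattened input sentence
    print(example["labels"])  # "Tag: value, ..." target text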