Spaces:
Runtime error
Runtime error
File size: 3,102 Bytes
7baf5b5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 |
"""
seq2seq models datasets
Classes:
MITRestaurants: tner/mit_restaurant dataset to seq2seq
Functions:
get_default_transforms: default transforms for mit dataset
"""
import datasets
class MITRestaurants:
    """
    tner/mit_restaurant dataset wrapper for seq2seq training.

    Attributes
    ----------
    ner_tags: mapping from NER tag name to its integer id
    dataset: the wrapped Hugging Face dataset
    transforms: transform functions applied in order by ``hf_training``
    """

    # NOTE(review): ids presumably mirror the label ids shipped with the
    # tner/mit_restaurant dataset -- confirm against the dataset card.
    ner_tags = {
        "O": 0,
        "B-Rating": 1,
        "I-Rating": 2,
        "B-Amenity": 3,
        "I-Amenity": 4,
        "B-Location": 5,
        "I-Location": 6,
        "B-Restaurant_Name": 7,
        "I-Restaurant_Name": 8,
        "B-Price": 9,
        "B-Hours": 10,
        "I-Hours": 11,
        "B-Dish": 12,
        "I-Dish": 13,
        "B-Cuisine": 14,
        "I-Price": 15,
        "I-Cuisine": 16,
    }

    # The annotation is quoted so it is not evaluated when the class is defined.
    def __init__(self, dataset: "datasets.DatasetDict", transforms=None):
        """
        Construct the MIT Restaurant dataset wrapper.

        Parameters:
            dataset: Hugging Face dataset to wrap
            transforms: optional iterable of functions, each passed to
                ``dataset.map`` by ``hf_training``
        """
        self.dataset = dataset
        self.transforms = transforms

    def hf_training(self):
        """
        Return the dataset for the Hugging Face training ecosystem.

        Every registered transform is applied through ``map``, in order.
        With no transforms the wrapped dataset is returned unchanged.
        """
        transformed = self.dataset
        if self.transforms:
            for transform in self.transforms:
                transformed = transformed.map(transform)
        return transformed

    def set_transforms(self, transforms):
        """
        Append transform functions to the ones already registered.

        Parameters:
            transforms: iterable of transform functions; ``None`` adds nothing

        Returns:
            self, so calls can be chained
        """
        # Build a fresh list rather than using ``+=``: the stored object may
        # be the very list the caller handed in, and extending it in place
        # would silently modify the caller's data.
        existing = list(self.transforms) if self.transforms else []
        self.transforms = existing + list(transforms or [])
        return self

    @classmethod
    def from_hf(cls, hf_path: str):
        """
        Construct the wrapper from a dataset loaded with ``load_dataset``.

        Parameters:
            hf_path: path to the dataset repo on the Hugging Face hub

        Returns:
            a new instance without any transforms registered
        """
        return cls(datasets.load_dataset(hf_path))
def get_default_transforms():
    """
    Build the default transforms for the MIT Restaurant dataset.

    Returns:
        A one-element list holding a function that turns an example with
        "tokens" and "tags" into a dict with "tokens" (the words joined by
        spaces) and "labels" (one "<entity>: <span>, <span>" line per
        entity type found in the example).
    """
    label_names = {v: k for k, v in MITRestaurants.ner_tags.items()}
    outside_id = MITRestaurants.ner_tags["O"]

    def decode_tags(tags, words):
        """
        Group tagged words into entity spans, keyed by entity type.

        The sequence is walked right-to-left, so spans (and dict keys) come
        out in reverse sentence order. An "I-" fragment that never reaches a
        "B-" tag of the same type is discarded.
        """
        dict_out = {}
        pending = []  # words of the span being assembled, right-to-left
        pending_type = None
        for tag, word in zip(reversed(tags), reversed(words)):
            if tag == outside_id:
                # An "O" tag ends any span; without this reset a dangling
                # "I-" fragment would be glued onto an earlier entity.
                pending, pending_type = [], None
                continue
            prefix, _, tag_name = label_names[tag].partition("-")
            if tag_name != pending_type:
                # Fragment of another entity type cannot belong to this span.
                pending, pending_type = [], tag_name
            pending.append(word)
            if prefix == "B":
                span = " ".join(reversed(pending)).strip()
                dict_out.setdefault(tag_name, []).append(span)
                pending, pending_type = [], None
        return dict_out

    def format_to_text(decoded):
        """Render decoded entities as newline-terminated "key: a, b" lines."""
        return "".join(
            f"{key}: {', '.join(value)}\n" for key, value in decoded.items()
        )

    def generate_seq2seq_data(example):
        """Convert one tagged example into seq2seq input and target text."""
        decoded = decode_tags(example["tags"], example["tokens"])
        return {
            "tokens": " ".join(example["tokens"]),
            "labels": format_to_text(decoded),
        }

    return [generate_seq2seq_data]
|