kowalsky commited on
Commit
30e1793
1 Parent(s): c8e808e

first commit

Browse files
__pycache__/data_loader.cpython-311.pyc ADDED
Binary file (2.27 kB). View file
 
__pycache__/metrics.cpython-311.pyc ADDED
Binary file (3.13 kB). View file
 
__pycache__/model.cpython-311.pyc ADDED
Binary file (9.38 kB). View file
 
__pycache__/utils.cpython-311.pyc ADDED
Binary file (4.75 kB). View file
 
app.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import os
3
+ import sys
4
+ import gradio as gr
5
+
6
+ project_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
7
+ sys.path.append(project_dir)
8
+
9
+ from model import MultiTaskBertModel
10
+ from data_loader import load_dataset
11
+ from utils import bert_config, tokenizer, intent_ids_to_labels, intent_labels_to_ids
12
+
13
+ config = bert_config()
14
+ dataset = load_dataset("training_dataset")
15
+ model = MultiTaskBertModel(config, dataset)
16
+
17
+ model.load_state_dict(torch.load("pytorch_model.bin"))
18
+
19
+ model.eval()
20
+
21
+ def predict(input_data):
22
+
23
+ tok = tokenizer()
24
+ preprocessed_input = tok(input_data,
25
+ return_offsets_mapping=True,
26
+ padding='max_length',
27
+ truncation=True,
28
+ max_length=128)
29
+
30
+ input_ids = torch.tensor([preprocessed_input['input_ids']])
31
+ attention_mask = torch.tensor([preprocessed_input['attention_mask']])
32
+ offset_mapping = torch.tensor(preprocessed_input['offset_mapping'])
33
+
34
+ with torch.no_grad():
35
+
36
+ ner_logits, intent_logits = model.forward(input_ids, attention_mask)
37
+
38
+ ner_logits = torch.argmax(ner_logits.view(-1, 9), dim=1)
39
+ intent_logits = torch.argmax(intent_logits)
40
+
41
+ aligned_predictions = []
42
+
43
+ for prediction, (start, end) in zip(ner_logits, offset_mapping):
44
+ if start == end:
45
+ continue
46
+
47
+ word = input_data[start:end]
48
+
49
+ if not word.strip():
50
+ continue
51
+
52
+ aligned_predictions.append((word, int(prediction)))
53
+
54
+ labels = intent_labels_to_ids()
55
+ intent_labels = intent_ids_to_labels(labels)
56
+
57
+ print(f"Ner logits: {aligned_predictions}")
58
+ print(f"Intent logits: {intent_labels}")
59
+
60
+ title = "Multi Task Model"
61
+ description = '''
62
+ The model was trained to do NER and Intent classification for a scheduler
63
+ '''
64
+
65
+ gr.Interface(
66
+ fn=predict,
67
+ inputs="text",
68
+ outputs="text",
69
+ title=title,
70
+ description=description,
71
+ examples=[["Remind me about the meeting at 3 PM tomorrow"], ["Set a timer for 10 minutes"]],
72
+ ).launch(share=True)
data_loader.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ from typing import Dict, List, Union
4
+ import sys
5
+
6
+ project_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
7
+ sys.path.append(project_dir)
8
+
9
+ from utils import structure_data
10
+
11
+
12
+ def load_dataset(dataset_name: str) -> Dict[str, Union[str, List[str]]]:
13
+ """
14
+ Load training dataset or validation dataset.
15
+
16
+ Args:
17
+ dataset_name (str): The name of the dataset. Should be either 'training_dataset' or 'validation_dataset'.
18
+
19
+ Returns:
20
+ dataset (Dict[str, Union[str. List[str]]]): A dictionary representing the
21
+ loaded dataset with keys 'text', 'ner', and 'intent'.
22
+
23
+ Raises:
24
+ ValueError: If the provided dataset_name is not one of the valid_names.
25
+ FileNotFoundError: If the dataset file is not found in the specified path.
26
+ """
27
+
28
+ valid_names = ["training_dataset", "validation_dataset"]
29
+
30
+ if dataset_name not in valid_names:
31
+ raise ValueError(f"Invalid dataset name. Expected one of {valid_names}, got {dataset_name}")
32
+
33
+ path = f"{dataset_name}.json"
34
+
35
+ if not os.path.exists(path):
36
+ raise FileNotFoundError(f"Dataset file not found at {path}")
37
+
38
+ with open(path, 'r') as f:
39
+ dataset = json.load(f)
40
+
41
+ dataset = structure_data(dataset)
42
+
43
+ return dataset
metrics.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torchmetrics import Metric
3
+
4
+ class MyAccuracy(Metric):
5
+ """
6
+ Accuracy metric costomized for handling sequences with padding.
7
+
8
+ Methods:
9
+ update(self, logits, labels, num_labels): Update the accuracy based on
10
+ model predictions and ground truth labels.
11
+
12
+ compute(self): Compute the accuracy.
13
+
14
+ Attributes:
15
+ total (torch.Tensor): Total number of non-padding elements.
16
+ correct (torch.Tensor): Number of correctly predicted non-padding elements.
17
+ """
18
+ def __init__(self):
19
+ super().__init__()
20
+ self.add_state('total', default=torch.tensor(0), dist_reduce_fx='sum')
21
+ self.add_state('correct', default=torch.tensor(0), dist_reduce_fx='sum')
22
+
23
+ def update(self, logits: torch.Tensor, labels: torch.Tensor, num_labels: int) -> None:
24
+ """
25
+ Args:
26
+ logits (torch.Tensor): Model predictions.
27
+ labels (torch.Tensor): Ground truth labels.
28
+ num_labels (int): Number of unique labels.
29
+ """
30
+ flattened_targets = labels.view(-1) # shape (batch_size, sequence_len)
31
+ active_logits = logits.view(-1, num_labels) # shape (batch_size * sequence_len, num_labels)
32
+ flattened_predictions = torch.argmax(active_logits, axis=1) # shape (batch_size * sequence_len)
33
+
34
+ # compute accuracy only at active labels
35
+ active_accuracy = labels.view(-1) != -100 # shape (batch_size, sequnce_len)
36
+ ac_labels = torch.masked_select(flattened_targets, active_accuracy)
37
+ predictions = torch.masked_select(flattened_predictions, active_accuracy)
38
+
39
+ self.correct += torch.sum(ac_labels == predictions)
40
+ self.total += torch.numel(ac_labels)
41
+
42
+ def compute(self) -> torch.Tensor:
43
+ """
44
+ Calculate the accuracy.
45
+ """
46
+ return self.correct.float() / self.total.float()
model.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import BertModel
2
+ import torch
3
+ import onnx
4
+ import pytorch_lightning as pl
5
+ import wandb
6
+ from metrics import MyAccuracy
7
+ from utils import num_unique_labels
8
+ from typing import Dict, Tuple, List, Optional
9
+
10
+ class MultiTaskBertModel(pl.LightningModule):
11
+
12
+ """
13
+ Multi-task Bert model for Named Entity Recognition (NER) and Intent Classification
14
+
15
+ Args:
16
+ config (BertConfig): Bert model configuration.
17
+ dataset (Dict[str, Union[str, List[str]]]): A dictionary containing keys 'text', 'ner', and 'intent'.
18
+ """
19
+
20
+ def __init__(self, config, dataset):
21
+ super().__init__()
22
+
23
+ self.num_ner_labels, self.num_intent_labels = num_unique_labels(dataset)
24
+
25
+ self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
26
+
27
+ self.model = BertModel(config=config)
28
+
29
+ self.ner_classifier = torch.nn.Linear(config.hidden_size, self.num_ner_labels)
30
+ self.intent_classifier = torch.nn.Linear(config.hidden_size, self.num_intent_labels)
31
+
32
+ # log hyperparameters
33
+ self.save_hyperparameters()
34
+
35
+ self.accuracy = MyAccuracy()
36
+
37
+ def forward(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
38
+
39
+ """
40
+ Perform a forward pass through Multi-task Bert model.
41
+
42
+ Args:
43
+ input_ids (torch.Tensor, torch.shape: (batch, length_of_tokenized_sequences)): Input token IDs.
44
+ attention_mask (Optional[torch.Tensor]): Attention mask for input tokens.
45
+
46
+ Returns:
47
+ Tuple[torch.Tensor,torch.Tensor]: NER logits, Intent logits.
48
+ """
49
+
50
+ outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
51
+
52
+ sequence_output = outputs[0]
53
+ sequence_output = self.dropout(sequence_output)
54
+ ner_logits = self.ner_classifier(sequence_output)
55
+
56
+ pooled_output = outputs[1]
57
+ pooled_output = self.dropout(pooled_output)
58
+ intent_logits = self.intent_classifier(pooled_output)
59
+
60
+ return ner_logits, intent_logits
61
+
62
+ def training_step(self: pl.LightningModule, batch, batch_idx: int) -> torch.Tensor:
63
+ """
64
+ Perform a training step for the Multi-task BERT model.
65
+
66
+ Args:
67
+ batch: Input batch.
68
+ batch_idx (int): Index of the batch.
69
+
70
+ Returns:
71
+ torch.Tensor: Loss value
72
+ """
73
+ loss, ner_logits, intent_logits, ner_labels, intent_labels = self._common_step(batch, batch_idx)
74
+ accuracy_ner = self.accuracy(ner_logits, ner_labels, self.num_ner_labels)
75
+ accuracy_intent = self.accuracy(intent_logits, intent_labels, self.num_intent_labels)
76
+ self.log_dict({'training_loss': loss, 'ner_accuracy': accuracy_ner, 'intent_accuracy': accuracy_intent},
77
+ on_step=False, on_epoch=True, prog_bar=True)
78
+ return loss
79
+
80
+ def on_validation_epoch_start(self):
81
+ self.validation_step_outputs_ner = []
82
+ self.validation_step_outputs_intent = []
83
+
84
+ def validation_step(self, batch, batch_idx: int) -> torch.Tensor:
85
+ """
86
+ Perform a validation step for the Multi-task BERT model.
87
+
88
+ Args:
89
+ batch: Input batch.
90
+ batch_idx (int): Index of the batch.
91
+
92
+ Returns:
93
+ torch.Tensor: Loss value.
94
+ """
95
+ loss, ner_logits, intent_logits, ner_labels, intent_labels = self._common_step(batch, batch_idx)
96
+ # self.log('val_loss', loss)
97
+ accuracy_ner = self.accuracy(ner_logits, ner_labels, self.num_ner_labels)
98
+ accuracy_intent = self.accuracy(intent_logits, intent_labels, self.num_intent_labels)
99
+ self.log_dict({'validation_loss': loss, 'val_ner_accuracy': accuracy_ner, 'val_intent_accuracy': accuracy_intent},
100
+ on_step=False, on_epoch=True, prog_bar=True)
101
+
102
+ self.validation_step_outputs_ner.append(ner_logits)
103
+ self.validation_step_outputs_intent.append(intent_logits)
104
+ return loss
105
+
106
+ def on_validation_epoch_end(self):
107
+ """
108
+ Perform actions at the end of validation epoch to track the training process in WandB.
109
+ """
110
+ validation_step_outputs_ner = self.validation_step_outputs_ner
111
+ validation_step_outputs_intent = self.validation_step_outputs_intent
112
+
113
+ dummy_input = torch.zeros((1, 128), device=self.device, dtype=torch.long)
114
+ model_filename = f"model_{str(self.global_step).zfill(5)}.onnx"
115
+ torch.onnx.export(self, dummy_input, model_filename)
116
+ artifact = wandb.Artifact(name="model.ckpt", type="model")
117
+ artifact.add_file(model_filename)
118
+ self.logger.experiment.log_artifact(artifact)
119
+
120
+ flattened_logits_ner = torch.flatten(torch.cat(validation_step_outputs_ner))
121
+ flattened_logits_intent = torch.flatten(torch.cat(validation_step_outputs_intent))
122
+ self.logger.experiment.log(
123
+ {"valid/ner_logits": wandb.Histogram(flattened_logits_ner.to('cpu')),
124
+ "valid/intent_logits": wandb.Histogram(flattened_logits_intent.to('cpu')),
125
+ "global_step": self.global_step}
126
+ )
127
+
128
+ def _common_step(self, batch, batch_idx):
129
+ """
130
+ Common steps for both training and validation. Calculate loss for both NER and intent layer.
131
+
132
+ Returns:
133
+ Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
134
+ Combiner loss value, NER logits, intent logits, NER labels, intent labels.
135
+ """
136
+ ids = batch['input_ids']
137
+ mask = batch['attention_mask']
138
+ ner_labels = batch['ner_labels']
139
+ intent_labels = batch['intent_labels']
140
+
141
+ ner_logits, intent_logits = self.forward(input_ids=ids, attention_mask=mask)
142
+
143
+ criterion = torch.nn.CrossEntropyLoss()
144
+
145
+ ner_loss = criterion(ner_logits.view(-1, self.num_ner_labels), ner_labels.view(-1).long())
146
+ intent_loss = criterion(intent_logits.view(-1, self.num_intent_labels), intent_labels.view(-1).long())
147
+
148
+ loss = ner_loss + intent_loss
149
+ return loss, ner_logits, intent_logits, ner_labels, intent_labels
150
+
151
+ def configure_optimizers(self):
152
+ optimizer = torch.optim.Adam(self.parameters(), lr=1e-5)
153
+ return optimizer
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cbd376379912824c97a8c347e155db8e458526183f8939c2c6b2b780ea8698cc
3
+ size 438053110
requirements.txt ADDED
@@ -0,0 +1,196 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ aiofiles==23.2.1
2
+ aiohttp==3.9.1
3
+ aiosignal==1.3.1
4
+ altair==5.2.0
5
+ annotated-types==0.6.0
6
+ anyio==3.7.1
7
+ appdirs==1.4.4
8
+ argon2-cffi==21.3.0
9
+ argon2-cffi-bindings==21.2.0
10
+ arrow==1.2.3
11
+ asttokens==2.2.1
12
+ async-lru==2.0.4
13
+ attrs==23.1.0
14
+ Babel==2.12.1
15
+ backcall==0.2.0
16
+ beautifulsoup4==4.12.2
17
+ bleach==6.0.0
18
+ blinker==1.6.2
19
+ certifi==2023.5.7
20
+ cffi==1.15.1
21
+ charset-normalizer==3.2.0
22
+ click==8.1.7
23
+ colorama==0.4.6
24
+ coloredlogs==15.0.1
25
+ comm==0.1.4
26
+ contourpy==1.1.0
27
+ cvxopt==1.3.2
28
+ cvxpy==1.3.2
29
+ cycler==0.11.0
30
+ debugpy==1.6.7
31
+ decorator==5.1.1
32
+ defusedxml==0.7.1
33
+ docker-pycreds==0.4.0
34
+ ecos==2.0.12
35
+ executing==1.2.0
36
+ fastapi==0.109.2
37
+ fastjsonschema==2.18.0
38
+ ffmpy==0.3.2
39
+ filelock==3.12.4
40
+ Flask==2.3.3
41
+ flatbuffers==23.5.26
42
+ fonttools==4.42.0
43
+ fqdn==1.5.1
44
+ frozenlist==1.4.1
45
+ fsspec==2023.9.2
46
+ gitdb==4.0.11
47
+ GitPython==3.1.41
48
+ gradio==4.19.1
49
+ gradio_client==0.10.0
50
+ h11==0.14.0
51
+ httpcore==1.0.3
52
+ httpx==0.26.0
53
+ huggingface-hub==0.20.3
54
+ humanfriendly==10.0
55
+ hypothesis==6.97.1
56
+ idna==3.4
57
+ importlib-resources==6.1.1
58
+ iniconfig==2.0.0
59
+ ipykernel==6.25.0
60
+ ipython==8.14.0
61
+ isoduration==20.11.0
62
+ itsdangerous==2.1.2
63
+ jedi==0.19.0
64
+ Jinja2==3.1.2
65
+ joblib==1.3.1
66
+ json5==0.9.14
67
+ jsonpointer==2.4
68
+ jsonschema==4.18.6
69
+ jsonschema-specifications==2023.7.1
70
+ jupyter-events==0.7.0
71
+ jupyter-lsp==2.2.0
72
+ jupyter_client==8.3.0
73
+ jupyter_core==5.3.1
74
+ jupyter_server==2.7.0
75
+ jupyter_server_terminals==0.4.4
76
+ jupyterlab==4.0.4
77
+ jupyterlab-pygments==0.2.2
78
+ jupyterlab_server==2.24.0
79
+ kiwisolver==1.4.4
80
+ lightning==2.1.3
81
+ lightning-utilities==0.10.1
82
+ lxml==4.9.3
83
+ markdown-it-py==3.0.0
84
+ MarkupSafe==2.1.3
85
+ matplotlib==3.7.2
86
+ matplotlib-inline==0.1.6
87
+ mdurl==0.1.2
88
+ mistune==3.0.1
89
+ mpmath==1.3.0
90
+ multidict==6.0.4
91
+ nbclient==0.8.0
92
+ nbconvert==7.7.3
93
+ nbformat==5.9.2
94
+ nest-asyncio==1.5.7
95
+ networkx==3.2.1
96
+ nnfs==0.5.1
97
+ notebook_shim==0.2.3
98
+ numpy==1.25.1
99
+ onnx==1.15.0
100
+ onnxruntime==1.17.0
101
+ orjson==3.9.14
102
+ osqp==0.6.3
103
+ overrides==7.4.0
104
+ packaging==23.1
105
+ pandas==2.0.3
106
+ pandocfilters==1.5.0
107
+ parso==0.8.3
108
+ pickleshare==0.7.5
109
+ Pillow==10.0.0
110
+ platformdirs==3.10.0
111
+ pluggy==1.4.0
112
+ praw==7.7.1
113
+ prawcore==2.4.0
114
+ prometheus-client==0.17.1
115
+ prompt-toolkit==3.0.39
116
+ protobuf==4.25.2
117
+ psutil==5.9.5
118
+ pure-eval==0.2.2
119
+ pyarrow==14.0.0
120
+ pycparser==2.21
121
+ pydantic==2.6.1
122
+ pydantic_core==2.16.2
123
+ pydub==0.25.1
124
+ pygame==2.5.0
125
+ Pygments==2.16.1
126
+ pyparsing==3.0.9
127
+ PyPDF2==3.0.1
128
+ pyreadline3==3.4.1
129
+ pytest==8.0.0
130
+ python-dateutil==2.8.2
131
+ python-docx==1.1.0
132
+ python-json-logger==2.0.7
133
+ python-multipart==0.0.9
134
+ pytorch-lightning==2.1.3
135
+ pytz==2023.3
136
+ pywin32==306
137
+ pywinpty==2.0.11
138
+ PyYAML==6.0.1
139
+ pyzmq==25.1.0
140
+ qdldl==0.1.7.post0
141
+ referencing==0.30.2
142
+ regex==2023.8.8
143
+ requests==2.31.0
144
+ rfc3339-validator==0.1.4
145
+ rfc3986-validator==0.1.1
146
+ rich==13.7.0
147
+ rpds-py==0.9.2
148
+ ruff==0.2.2
149
+ safetensors==0.3.3
150
+ scikit-learn==1.3.0
151
+ scipy==1.11.1
152
+ scs==3.2.3
153
+ seaborn==0.12.2
154
+ semantic-version==2.10.0
155
+ Send2Trash==1.8.2
156
+ sentry-sdk==1.39.2
157
+ setproctitle==1.3.3
158
+ shellingham==1.5.4
159
+ six==1.16.0
160
+ smmap==5.0.1
161
+ sniffio==1.3.0
162
+ sortedcontainers==2.4.0
163
+ soupsieve==2.4.1
164
+ stack-data==0.6.2
165
+ starlette==0.36.3
166
+ sympy==1.12
167
+ terminado==0.17.1
168
+ threadpoolctl==3.2.0
169
+ tinycss2==1.2.1
170
+ tokenizers==0.13.3
171
+ tomlkit==0.12.0
172
+ toolz==0.12.1
173
+ torch==2.1.2
174
+ torchaudio==2.1.2
175
+ torchmetrics==1.3.0.post0
176
+ torchvision==0.16.2
177
+ tornado==6.3.2
178
+ tqdm==4.66.1
179
+ traitlets==5.9.0
180
+ transformers==4.33.2
181
+ typer==0.9.0
182
+ typing_extensions==4.8.0
183
+ tzdata==2023.3
184
+ update-checker==0.18.0
185
+ uri-template==1.3.0
186
+ urllib3==2.0.4
187
+ uvicorn==0.27.1
188
+ wandb==0.16.2
189
+ wcwidth==0.2.6
190
+ webcolors==1.13
191
+ webencodings==0.5.1
192
+ websocket-client==1.6.1
193
+ websockets==11.0.3
194
+ Werkzeug==2.3.7
195
+ windows-curses==2.3.1
196
+ yarl==1.9.4
training_dataset.json ADDED
@@ -0,0 +1,177 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [
2
+ {
3
+ "text": "Set a timer for 10 minutes.",
4
+ "intent": "'Set Timer'",
5
+ "entities": "O O O O B-DUR I-DUR"
6
+ },
7
+ {
8
+ "text": "Remind me about the meeting at 3 PM tomorrow.",
9
+ "intent": "'Set Reminder'",
10
+ "entities": "O O O O O O B-TIME I-TIME B-DATE"
11
+ },
12
+ {
13
+ "text": "Schedule an appointment for next Friday at 9 AM.",
14
+ "intent": "'Schedule Appointment'",
15
+ "entities": "O O O O B-DATE I-DATE O B-TIME I-TIME"
16
+ },
17
+ {
18
+ "text": "Can you set a reminder for my doctor's appointment on Monday?",
19
+ "intent": "'Set Reminder'",
20
+ "entities": "O O O O O O O O O O B-DATE"
21
+ },
22
+ {
23
+ "text": "I want to schedule a meeting for the 15th of this month at 2:30 PM.",
24
+ "intent": "'Schedule Meeting'",
25
+ "entities": "O O O O O O O O B-DATE I-DATE I-DATE I-DATE O B-TIME I-TIME I-TIME I-TIME"
26
+ },
27
+ {
28
+ "text": "Set an alarm for 7 AM.",
29
+ "intent": "'Set Alarm'",
30
+ "entities": "O O O O B-TIME I-TIME"
31
+ },
32
+ {
33
+ "text": "Remind me to call John in 30 minutes.",
34
+ "intent": "'Set Reminder'",
35
+ "entities": "O O O B-TASK I-TASK O B-DUR I-DUR"
36
+ },
37
+ {
38
+ "text": "\"Schedule a meeting for next Wednesday afternoon.\"",
39
+ "intent": "'Schedule Meeting'",
40
+ "entities": "O O O O B-DATE I-DATE B-TIME"
41
+ },
42
+ {
43
+ "text": "Can you set a timer for cooking for 1 hour?",
44
+ "intent": "'Set Timer'",
45
+ "entities": "O O O O O O B-TASK O B-DUR I-DUR"
46
+ },
47
+ {
48
+ "text": "Remind me about the project deadline at 5 PM on Friday.",
49
+ "intent": "'Set Reminder'",
50
+ "entities": "O O O O B-TASK I-TASK O B-TIME I-TIME O B-DATE"
51
+ },
52
+ {
53
+ "text": "Schedule a doctor's appointment for March 20th at 10:30 AM.",
54
+ "intent": "'Schedule Appointment'",
55
+ "entities": "O O O O O B-DATE I-DATE O B-TIME I-TIME I-TIME I-TIME"
56
+ },
57
+ {
58
+ "text": "Set a timer for a 15-minute break.",
59
+ "intent": "'Set Timer'",
60
+ "entities": "O O O O O B-DUR I-DUR I-DUR B-TASK"
61
+ },
62
+ {
63
+ "text": "Remind me to buy groceries tomorrow morning.",
64
+ "intent": "'Set Reminder'",
65
+ "entities": "O O O B-TASK I-TASK B-DATE B-TIME"
66
+ },
67
+ {
68
+ "text": "Schedule a conference call for the first Monday of next month at 3 PM.",
69
+ "intent": "'Schedule Meeting'",
70
+ "entities": "O O O O O O B-DATE I-DATE I-DATE I-DATE I-DATE O B-TIME I-TIME"
71
+ },
72
+ {
73
+ "text": "Can you remind me to send the report at 4:30 PM today?",
74
+ "intent": "'Set Reminder'",
75
+ "entities": "O O O O O B-TASK I-TASK I-TASK O B-TIME I-TIME I-TIME I-TIME B-DATE"
76
+ },
77
+ {
78
+ "text": "Set a timer for a 20-minute workout session.",
79
+ "intent": "'Set Timer'",
80
+ "entities": "O O O O O B-DUR I-DUR I-DUR B-TASK I-TASK"
81
+ },
82
+ {
83
+ "text": "Remind me to water the plants every Tuesday and Thursday at 9 AM.",
84
+ "intent": "'Set Reminder'",
85
+ "entities": "O O O B-TASK I-TASK I-TASK O B-DATE O B-DATE O B-TIME I-TIME"
86
+ },
87
+ {
88
+ "text": "Schedule a team meeting for next Monday morning at 10:30.",
89
+ "intent": "'Schedule Meeting'",
90
+ "entities": "O O O O O B-DATE I-DATE B-TIME I-TIME I-TIME I-TIME I-TIME"
91
+ },
92
+ {
93
+ "text": "Can you set an alarm for 6:45 AM?",
94
+ "intent": "'Set Alarm'",
95
+ "entities": "O O O O O O B-TIME I-TIME I-TIME I-TIME"
96
+ },
97
+ {
98
+ "text": "Remind me about the webinar in 2 days at 2 PM.",
99
+ "intent": "'Set Reminder'",
100
+ "entities": "O O O O B-TASK O B-DUR I-DUR O B-TIME I-TIME"
101
+ },
102
+ {
103
+ "text": "Schedule a dentist appointment for April 5th at 11:00 in the morning.",
104
+ "intent": "'Schedule Appointment'",
105
+ "entities": "O O B-TASK I-TASK O B-DATE I-DATE O B-TIME I-TIME I-TIME I-TIME I-TIME I-TIME"
106
+ },
107
+ {
108
+ "text": "Set a timer for a 5-minute meditation session.",
109
+ "intent": "'Set Timer'",
110
+ "entities": "O O O O O B-DUR I-DUR I-DUR B-TASK I-TASK"
111
+ },
112
+ {
113
+ "text": "Remind me to call Sarah next Wednesday afternoon.",
114
+ "intent": "'Set Reminder'",
115
+ "entities": "O O O B-TASK I-TASK B-DATE I-DATE B-TIME"
116
+ },
117
+ {
118
+ "text": "Schedule a review meeting for the end of the month at 4:30 PM.",
119
+ "intent": "'Schedule Meeting'",
120
+ "entities": "O O B-TASK I-TASK O O B-DATE I-DATE I-DATE I-DATE O B-TIME I-TIME I-TIME I-TIME"
121
+ },
122
+ {
123
+ "text": "Can you remind me to pay bills on the last day of the month?",
124
+ "intent": "'Set Reminder'",
125
+ "entities": "O O O O O B-TASK I-TASK O O B-DATE I-DATE I-DATE I-DATE I-DATE"
126
+ },
127
+ {
128
+ "text": "Set a timer for 45 minutes for a study session.",
129
+ "intent": "'Set Timer'",
130
+ "entities": "O O O O B-DUR I-DUR O O B-TASK I-TASK"
131
+ },
132
+ {
133
+ "text": "Remind me to pick up the laundry every Friday afternoon.",
134
+ "intent": "'Set Reminder'",
135
+ "entities": "O O O B-TASK I-TASK I-TASK I-TASK O B-DATE B-TIME"
136
+ },
137
+ {
138
+ "text": "Schedule a client meeting for the 10th of next month at 2 PM.",
139
+ "intent": "'Schedule Meeting'",
140
+ "entities": "O O B-TASK I-TASK O O B-DATE I-DATE I-DATE I-DATE O B-TIME I-TIME"
141
+ },
142
+ {
143
+ "text": "Can you set an alarm for 7:30 AM tomorrow?",
144
+ "intent": "'Set Alarm'",
145
+ "entities": "O O O O O O B-TIME I-TIME I-TIME I-TIME B-DATE"
146
+ },
147
+ {
148
+ "text": "Remind me about the presentation at 4 PM today.",
149
+ "intent": "'Set Reminder'",
150
+ "entities": "O O O O B-TASK O B-TIME I-TIME B-DATE"
151
+ },
152
+ {
153
+ "text": "Schedule a doctor's appointment for May 15th in the evening.",
154
+ "intent": "'Schedule Appointment'",
155
+ "entities": "O O B-TASK I-TASK O B-DATE I-DATE O O B-TIME"
156
+ },
157
+ {
158
+ "text": "Set a timer for a 10-minute break between study sessions.",
159
+ "intent": "'Set Timer'",
160
+ "entities": "O O O O O B-DUR I-DUR I-DUR O O O O"
161
+ },
162
+ {
163
+ "text": "Remind me to send the report at 9 AM tomorrow.",
164
+ "intent": "'Set Reminder'",
165
+ "entities": "O O O B-TASK I-TASK I-TASK O B-TIME I-TIME B-DATE"
166
+ },
167
+ {
168
+ "text": "Schedule a team lunch for next Friday at noon.",
169
+ "intent": "'Schedule Meeting'",
170
+ "entities": "O O B-TASK I-TASK O B-DATE I-DATE O B-TIME"
171
+ },
172
+ {
173
+ "text": "Can you remind me to buy groceries on Saturday afternoon?",
174
+ "intent": "'Set Reminder'",
175
+ "entities": "O O O O O B-TASK I-TASK O B-DATE B-TIME"
176
+ }
177
+ ]
utils.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import BertTokenizerFast, BertConfig
2
+ from typing import Dict, List, Union, Tuple
3
+
4
+
5
+ def num_unique_labels(dataset: Dict[str, Union[str, List[str]]]) -> Tuple[int, int]:
6
+ """
7
+ Calculate the number of NER labels and INTENT labels in the dataset.
8
+
9
+ Args:
10
+ dataset (dict): A dictionary containing 'text', 'entities' and 'intent' keys.
11
+
12
+ Returns:
13
+ Tuple: Number of unique NER and INTENT lables.
14
+ """
15
+ one_dimensional_ner = [tag for subset in dataset['entities'] for tag in subset]
16
+ return len(set(one_dimensional_ner)), len(set(dataset['intent']))
17
+
18
+ def ner_labels_to_ids() -> Dict[str, int]:
19
+ """
20
+ Map NER labels to corresponding numeric IDs.
21
+
22
+ Returns:
23
+ Dict[str, int]: A dictionary where keys are NER labels, and values are their corresponding IDs.
24
+ """
25
+ labels_to_ids_ner = {
26
+ 'O': 0,
27
+ 'B-DATE': 1,
28
+ 'I-DATE': 2,
29
+ 'B-TIME': 3,
30
+ 'I-TIME': 4,
31
+ 'B-TASK': 5,
32
+ 'I-TASK': 6,
33
+ 'B-DUR': 7,
34
+ 'I-DUR': 8
35
+ }
36
+ return labels_to_ids_ner
37
+
38
+ def ner_ids_to_labels(ner_labels_to_ids) -> Dict[int, str]:
39
+ """
40
+ Map numeric IDs to corresponding NER labels.
41
+
42
+ Returns:
43
+ Dict[int, str]: A dictionary where keys are numeric IDs, and values are their corresponding NER labels.
44
+ """
45
+ ner_ids_to_labels = {v: k for k, v in ner_labels_to_ids.items()}
46
+ return ner_ids_to_labels
47
+
48
+ def intent_labels_to_ids() -> Dict[str, int]:
49
+ """
50
+ Map intent labels to corresponding numeric values.
51
+
52
+ Returns:
53
+ Dict[str, int]: A dictionary where keys are intent labels, and values are their corresponding numeric IDs.
54
+ """
55
+ intent_labels_to_ids = {
56
+ "'Schedule Appointment'": 0,
57
+ "'Schedule Meeting'": 1,
58
+ "'Set Alarm'": 2,
59
+ "'Set Reminder'": 3,
60
+ "'Set Timer'": 4
61
+ }
62
+ return intent_labels_to_ids
63
+
64
+ def intent_ids_to_labels(intent_labels_to_ids) -> Dict[int, str]:
65
+ """
66
+ Map numeric values to corresponding intent labels.
67
+
68
+ Returns:
69
+ Dict[int, str]: A dictionary where keys are numeric IDs, and values are their corresponding intent labels.
70
+ """
71
+ intent_ids_to_labels = {v: k for k, v in intent_labels_to_ids.items()}
72
+ return intent_ids_to_labels
73
+
74
+ def tokenizer() -> BertTokenizerFast:
75
+ tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
76
+ return tokenizer
77
+
78
+ def bert_config() -> BertConfig:
79
+ config = BertConfig.from_pretrained('bert-base-uncased')
80
+ return config
81
+
82
+ def structure_data(dataset):
83
+ structured_data = {'text': [], 'entities': [], 'intent': []}
84
+ for sample in dataset:
85
+ structured_data['text'].append(sample['text'])
86
+ structured_data['entities'].append(sample['entities'].split())
87
+ structured_data['intent'].append(sample['intent'])
88
+ return structured_data