# File size: 2,076 Bytes
# 30e1793
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import torch
import os
import sys
import gradio as gr

# Put the directory one level above this script on the import path so the
# local project modules imported below (model, data_loader, utils) resolve.
_script_dir = os.path.dirname(os.path.abspath(__file__))
project_dir = os.path.dirname(_script_dir)
sys.path.append(project_dir)

from model import MultiTaskBertModel
from data_loader import load_dataset
from utils import bert_config, tokenizer, intent_ids_to_labels, intent_labels_to_ids

config = bert_config()
dataset = load_dataset("training_dataset")
model = MultiTaskBertModel(config, dataset)

# map_location="cpu" lets a checkpoint saved from a GPU run load on a
# CPU-only host; load_state_dict copies the weights into the model's own
# parameters (presumably on CPU, since nothing here moves the model) either way.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints
# (consider weights_only=True if the installed torch version supports it).
# NOTE(review): "pytorch_model.bin" is resolved against the current working
# directory, not this script's location -- confirm the launch directory.
model.load_state_dict(torch.load("pytorch_model.bin", map_location="cpu"))

# Switch to inference mode (disables train-only layers such as dropout).
model.eval()

def _align_ner_predictions(text, label_ids, offsets):
    """Pair each token's source-text span with its predicted NER label id.

    Args:
        text: The original input string the offsets refer to.
        label_ids: Per-token predicted label ids (tensor or sequence of ints).
        offsets: Per-token ``(start, end)`` character offsets into ``text``.

    Returns:
        A list of ``(span_text, label_id)`` tuples, skipping zero-width spans
        and spans that are only whitespace.
    """
    aligned = []
    for label_id, (start, end) in zip(label_ids, offsets):
        start, end = int(start), int(end)
        # Zero-width spans (presumably special/padding tokens) carry no text.
        if start == end:
            continue
        span = text[start:end]
        if not span.strip():
            continue
        aligned.append((span, int(label_id)))
    return aligned


def _lookup_label(mapping, index):
    """Return the label for ``index`` from an id->label ``mapping``.

    Tries ``mapping[index]`` and then ``mapping[str(index)]`` (the key type of
    the mapping is not visible here); if neither works, falls back to the id as
    a string so the UI still shows something instead of crashing.
    """
    for key in (index, str(index)):
        try:
            return mapping[key]
        except (KeyError, IndexError, TypeError):
            continue
    return str(index)


def predict(input_data):
    """Run joint NER + intent prediction on one utterance for the Gradio UI.

    Tokenizes ``input_data`` (padded/truncated to 128 tokens), runs the
    module-level ``model`` once without gradients, picks the best NER label per
    token and the best intent, and aligns the token predictions back to
    substrings of the input via the tokenizer's character offsets.

    Args:
        input_data: Raw text entered in the Gradio textbox.

    Returns:
        A human-readable string with the predicted intent and the list of
        ``(text_span, ner_label_id)`` pairs, for the ``"text"`` Gradio output.
    """
    tok = tokenizer()
    encoded = tok(input_data,
                  return_offsets_mapping=True,
                  padding='max_length',
                  truncation=True,
                  max_length=128)

    # Batch of one: wrap the per-token lists in an outer list.
    input_ids = torch.tensor([encoded['input_ids']])
    attention_mask = torch.tensor([encoded['attention_mask']])
    # Per-token (start, end) character offsets; kept as plain values because
    # they are only used to slice ``input_data``.
    offset_mapping = encoded['offset_mapping']

    with torch.no_grad():
        # Call the module itself (not .forward) so registered hooks still run.
        ner_logits, intent_logits = model(input_ids, attention_mask)

    # Assumes ner_logits is (batch, seq_len, num_labels) -- TODO confirm.
    # Flatten to (tokens, num_labels) using the real label count instead of a
    # hard-coded 9, then take the best label per token.
    num_ner_labels = ner_logits.size(-1)
    ner_pred_ids = torch.argmax(ner_logits.reshape(-1, num_ner_labels), dim=1)
    # Batch size is 1, so a flat argmax is the predicted intent id.
    intent_id = int(torch.argmax(intent_logits))

    aligned_predictions = _align_ner_predictions(input_data, ner_pred_ids, offset_mapping)

    # NOTE(review): assumes intent_ids_to_labels(intent_labels_to_ids()) yields
    # an id -> label mapping (dict or list) -- confirm in utils.
    id_to_intent = intent_ids_to_labels(intent_labels_to_ids())
    intent_label = _lookup_label(id_to_intent, intent_id)

    print(f"Ner logits: {aligned_predictions}")
    print(f"Intent prediction: {intent_label}")

    # Gradio's "text" output displays whatever this function returns; the
    # original returned None, leaving the output box empty.
    return f"Intent: {intent_label}\nEntities: {aligned_predictions}"

# Gradio demo: one text box in; whatever predict() returns is shown as text.
title = "Multi Task Model"
description = '''
The model was trained to do NER and Intent classification for a scheduler
'''

demo = gr.Interface(
    fn=predict,
    inputs="text",
    outputs="text",
    title=title,
    description=description,
    examples=[
        ["Remind me about the meeting at 3 PM tomorrow"],
        ["Set a timer for 10 minutes"],
    ],
)
# Per the Gradio docs, share=True also requests a public share link in
# addition to the local server.
demo.launch(share=True)