multi_task_bert / app.py
kowalsky's picture
updated
d79be0d
raw
history blame contribute delete
No virus
2.69 kB
import torch
import os
import sys
import gradio as gr
project_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(project_dir)
from model import MultiTaskBertModel
from data_loader import load_dataset
from utils import bert_config, tokenizer, intent_ids_to_labels, intent_labels_to_ids, ner_labels_to_ids, ner_ids_to_labels
# Build the multi-task model once at import time and restore trained weights.
config = bert_config()
dataset = load_dataset("training_dataset")
# NOTE(review): the dataset is passed to the model constructor; presumably it
# is used to size the task heads — confirm in model.py.
model = MultiTaskBertModel(config, dataset)
# map_location="cpu" lets a checkpoint saved from a GPU run load on a CPU-only
# host. load_state_dict copies the tensors into the model's own parameters,
# which nothing here moves off the default (CPU) device.
# NOTE(review): the relative path resolves against the current working
# directory, not this file's directory.
model.load_state_dict(torch.load("pytorch_model.bin", map_location="cpu"))
model.eval()  # evaluation mode for inference (affects e.g. dropout layers)

# NER label <-> id lookups; ner_id_to_label decodes per-token predictions.
ner_label_to_id = ner_labels_to_ids()
ner_id_to_label = ner_ids_to_labels(ner_label_to_id)
def predict(input_data):
    """Run the multi-task model on one utterance and format its predictions.

    The text is tokenized (padded and truncated to 128 tokens) and passed
    through ``model`` once. The two heads are then decoded:

    * NER: the arg-max label id of every token is mapped through
      ``ner_id_to_label``. It is paired with the slice of ``input_data``
      that the token covers. Tokens with a zero-width span or a
      whitespace-only span are skipped.
    * Intent: the arg-max over the intent logits is mapped to its label.

    Args:
        input_data: The raw text entered in the Gradio textbox.

    Returns:
        A display string of the form
        ``"Ner logits: [(span, label), ...], Intent Label: <label>"``.
    """
    tok = tokenizer()
    encoding = tok(input_data,
                   return_offsets_mapping=True,
                   padding='max_length',
                   truncation=True,
                   max_length=128)
    # Add a batch dimension of 1: the model sees a single sequence.
    input_ids = torch.tensor([encoding['input_ids']])
    attention_mask = torch.tensor([encoding['attention_mask']])
    # Per-token (start, end) character offsets. Kept as plain Python values
    # so they can slice ``input_data`` directly.
    offsets = encoding['offset_mapping']

    with torch.no_grad():
        # Call the module instance (not .forward) so registered hooks run.
        ner_logits, intent_logits = model(input_ids, attention_mask)

    # One predicted NER id per token. The class count is read from the last
    # dimension rather than hard-coded, so a changed NER label set cannot
    # silently misalign rows. Assumes class scores sit on the last
    # dimension, as the previous `.view(-1, 9)` already implied.
    num_ner_classes = ner_logits.size(-1)
    ner_pred_ids = ner_logits.reshape(-1, num_ner_classes).argmax(dim=1).tolist()
    # Arg-max over the flattened intent logits (a batch of one sequence).
    intent_pred_id = int(torch.argmax(intent_logits))

    aligned_predictions = []
    for pred_id, (start, end) in zip(ner_pred_ids, offsets):
        if start == end:
            # Zero-width span: presumably a special or padding token.
            continue
        word = input_data[start:end]
        if not word.strip():
            continue
        aligned_predictions.append((word, ner_id_to_label[pred_id]))

    intent_id_to_label = intent_ids_to_labels(intent_labels_to_ids())
    intent_label = intent_id_to_label[intent_pred_id]
    return f"Ner logits: {aligned_predictions}, Intent Label: {intent_label}"
# Header text for the demo page: `title` and `description` are passed to
# gr.Interface below. The description string contains inline HTML (an <img>
# tag), which relies on Gradio rendering it as Markdown/HTML.
# NOTE(review): whether a relative image path like "bart.jpg" is actually
# served depends on the Gradio version and its static-file settings —
# confirm that the image renders.
title = "Multi Task Model"
description = '''
This model is designed for a scheduler application, capable of handling various tasks such as setting
timers, scheduling meetings, appointments, and alarms. It provides Named Entity Recognition (NER) labels
to identify specific entities within the input text, along with an Intent label to determine the
overall task intention. The model's outputs facilitate efficient task management and organization,
enabling seamless interaction with the scheduler application.
<img src="bart.jpg" width=300px>
'''
# Sample utterances offered as clickable examples; each inner list holds the
# value for the single text input.
example_prompts = [
    ["Remind me about the meeting at 3 PM tomorrow"],
    ["Set a timer for 10 minutes"],
]

# Text-in / text-out UI around `predict`, using the header text defined above.
demo = gr.Interface(
    fn=predict,
    inputs="text",
    outputs="text",
    title=title,
    description=description,
    examples=example_prompts,
)

# Start serving; share=True asks Gradio for a public share link.
demo.launch(share=True)