"""Gradio demo for the multi-task (NER + intent) BERT scheduler model.

Loads the trained ``MultiTaskBertModel`` from ``pytorch_model.bin`` and serves
a text-in / text-out Gradio interface that reports per-token NER labels and the
predicted intent label for the input sentence.
"""
import os
import sys

import gradio as gr
import torch

# Make the project's top-level modules (model, data_loader, utils) importable
# when this script is run from its own sub-directory.
project_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(project_dir)

from model import MultiTaskBertModel  # noqa: E402
from data_loader import load_dataset  # noqa: E402
from utils import (  # noqa: E402
    bert_config,
    tokenizer,
    intent_ids_to_labels,
    intent_labels_to_ids,
    ner_labels_to_ids,
    ner_ids_to_labels,
)

config = bert_config()
dataset = load_dataset("training_dataset")
model = MultiTaskBertModel(config, dataset)
# All inference below runs on CPU tensors, so map the checkpoint to CPU.
# Without this, a checkpoint saved from a GPU fails to load on a CPU-only host.
# NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
model.load_state_dict(torch.load("pytorch_model.bin", map_location="cpu"))
model.eval()

ner_label_to_id = ner_labels_to_ids()
ner_id_to_label = ner_ids_to_labels(ner_label_to_id)

# Built once at start-up rather than on every request.
tok = tokenizer()
intent_id_to_label = intent_ids_to_labels(intent_labels_to_ids())


def predict(input_data):
    """Run the multi-task model on one sentence and format its predictions.

    Args:
        input_data: Raw input sentence from the Gradio text box.

    Returns:
        A string of the form
        ``"Ner logits: [(word, ner_label), ...], Intent Label: <label>"``.
        The "Ner logits" key is kept for output compatibility; the values are
        predicted label names, not raw logits.
    """
    encoded = tok(
        input_data,
        return_offsets_mapping=True,
        padding='max_length',
        truncation=True,
        max_length=128,
    )
    input_ids = torch.tensor([encoded['input_ids']])
    attention_mask = torch.tensor([encoded['attention_mask']])

    with torch.no_grad():
        # Call the module (not .forward) so any registered hooks still run.
        ner_logits, intent_logits = model(input_ids, attention_mask)

    # One predicted label id per token. The number of NER labels comes from
    # the logits' last dimension instead of being hard-coded.
    ner_predictions = (
        ner_logits.reshape(-1, ner_logits.shape[-1]).argmax(dim=-1).tolist()
    )
    intent_id = int(intent_logits.argmax())

    aligned_predictions = []
    for prediction, (start, end) in zip(ner_predictions, encoded['offset_mapping']):
        # Empty character spans are special/padding tokens with no source text.
        if start == end:
            continue
        word = input_data[start:end]
        if not word.strip():
            continue
        aligned_predictions.append((word, ner_id_to_label[prediction]))

    intent_labels = intent_id_to_label[intent_id]
    return f"Ner logits: {aligned_predictions}, Intent Label: {intent_labels}"


title = "Multi Task Model"
description = ''' This model is designed for a scheduler application, capable of handling various tasks such as setting timers, scheduling meetings, appointments, and alarms. It provides Named Entity Recognition (NER) labels to identify specific entities within the input text, along with an Intent label to determine the overall task intention. 
The model's outputs facilitate efficient task management and organization, enabling seamless interaction with the scheduler application. '''

gr.Interface(
    fn=predict,
    inputs="text",
    outputs="text",
    title=title,
    description=description,
    examples=[["Remind me about the meeting at 3 PM tomorrow"], ["Set a timer for 10 minutes"]],
).launch(share=True)