Anshoo Mehra committed on
Commit
46b59f3
1 Parent(s): 374a72c

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +97 -0
app.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import pipeline
3
+ import nltk
4
+ nltk.download('punkt')
5
+ from nltk import sent_tokenize
6
+ import torch
7
+ from transformers import (
8
+ pipeline,
9
+ AutoModelForSeq2SeqLM,
10
+ AutoTokenizer
11
+ )
12
+
13
+ import re
14
+
15
# Device handed to both models: CUDA device index 0 when a GPU is visible to
# torch, otherwise the string 'cpu'.
device = 0 if torch.cuda.is_available() else 'cpu'
16
+
17
# Loaded (model, tokenizer) pairs keyed by (checkpoint name, device) so the
# checkpoint is read and moved to the device once per process instead of on
# every request.
_SEQ2SEQ_CACHE = {}


def _load_seq2seq(model, device):
    """Return the cached ``(model, tokenizer)`` pair for *model* on *device*."""
    key = (model, device)
    if key not in _SEQ2SEQ_CACHE:
        _SEQ2SEQ_CACHE[key] = (
            AutoModelForSeq2SeqLM.from_pretrained(model).to(device),
            AutoTokenizer.from_pretrained(model),
        )
    return _SEQ2SEQ_CACHE[key]


def _generate(query, context, model, device):
    """Generate an answer to *query* from *context* with a seq2seq checkpoint.

    Args:
        query: The question text.
        context: The passage the answer should be drawn from.
        model: Checkpoint name or path accepted by ``from_pretrained``.
        device: Device the model and the input tensor are moved to.

    Returns:
        The first generated sequence decoded to a string, with special
        tokens removed.
    """
    ft_model, ft_tokenizer = _load_seq2seq(model, device)

    # Prompt layout is reproduced verbatim from the original code.
    input_text = "question: " + query + "</s> question_context: " + context

    # Inputs are truncated / padded to exactly 1024 tokens.
    input_tokenized = ft_tokenizer.encode(
        input_text,
        return_tensors='pt',
        truncation=True,
        padding='max_length',
        max_length=1024,
    ).to(device)

    summary_ids = ft_model.generate(
        input_tokenized,
        max_length=30,
        min_length=3,
        length_penalty=1.0,
        num_beams=2,
        early_stopping=True,
    )

    # Only the first returned sequence is used, so decode just that one.
    answer = ft_tokenizer.decode(
        summary_ids[0],
        clean_up_tokenization_spaces=True,
        skip_special_tokens=True,
    )
    return str(answer)
36
+
37
# Question-answering pipelines keyed by (checkpoint name, device) so the
# extractive model is built once per process rather than on every click.
_QA_PIPELINE_CACHE = {}

_NO_SENTENCE_MATCH = "Failed matching sentence (answer may be split in multiple sentences)"
_NO_SPAN_MATCH = "Failed matching span (answer not found verbatim in context)"


def _get_qa_pipeline(model_name):
    """Return the cached question-answering pipeline for *model_name*."""
    key = (model_name, device)
    if key not in _QA_PIPELINE_CACHE:
        _QA_PIPELINE_CACHE[key] = pipeline(
            'question-answering',
            model=model_name,
            tokenizer=model_name,
            device=device,
        )
    return _QA_PIPELINE_CACHE[key]


def _first_sentence_containing(answer, sentences):
    """Return the first sentence that contains *answer*, or a failure message."""
    for sentence in sentences:
        if answer in sentence:
            return sentence
    return _NO_SENTENCE_MATCH


def predict(query, context):
    """Answer *query* from *context* with an extractive and a generative model.

    Args:
        query: The question text.
        context: The passage to search; non-ASCII characters are dropped
            before it is passed to either model.

    Returns:
        A 7-tuple, in the order expected by the UI outputs:
        extractive answer, its containing sentence, its ``"[start,end]"``
        span string, its confidence score, generative answer, its containing
        sentence, and its ``(start, end)`` span in the context (or a failure
        message when the generated text does not occur verbatim).
    """
    # Strip anything outside ASCII before handing the text to the models.
    context = context.encode("ascii", "ignore").decode()
    sentences = sent_tokenize(context)

    # Custom1: extractive question answering.
    cust_model_name = "anshoomehra/question-answering-roberta-base-s"
    cust_output = _get_qa_pipeline(cust_model_name)(question=query, context=context)
    cust_answer = cust_output['answer']
    cust_answer_span = "[" + str(cust_output['start']) + "," + str(cust_output['end']) + "]"
    cust_confidence = cust_output['score']
    cust_answer_sentence = _first_sentence_containing(cust_answer, sentences)

    # Custom2: generative question answering.
    cust_answer_2 = _generate(
        query,
        context,
        model="anshoomehra/question-answering-generative-t5-v1-base-s-q-c",
        device=device,
    )
    cust_answer_sentence_2 = _first_sentence_containing(cust_answer_2, sentences)

    # Locate the generated text in the context that was actually submitted
    # (not the UI default), as a literal substring so regex metacharacters in
    # the answer cannot break or distort the search.
    start = context.find(cust_answer_2)
    if start == -1:
        cust_answer_span_2 = _NO_SPAN_MATCH
    else:
        cust_answer_span_2 = (start, start + len(cust_answer_2))

    return (
        cust_answer,
        cust_answer_sentence,
        cust_answer_span,
        cust_confidence,
        cust_answer_2,
        cust_answer_sentence_2,
        cust_answer_span_2,
    )
66
+
67
# Default values pre-filled in the two input boxes. These stay module-level
# names because other code in this file may read them.
queryDefault = "What are the Big Five American information technology companies?"
contextDefault = "Google LLC is an American multinational technology company focusing on search engine technology, online advertising, cloud computing, computer software, quantum computing, e-commerce, artificial intelligence, and consumer electronics. It has been referred to as 'the most powerful company in the world' and one of the world's most valuable brands due to its market dominance, data collection, and technological advantages in the area of artificial intelligence. Its parent company Alphabet is considered one of the Big Five American information technology companies, alongside Amazon, Apple, Meta, and Microsoft."

# NOTE(review): the nesting below is reconstructed from a flattened diff view;
# confirm the button row sits at the Blocks level rather than inside the
# output accordion.
with gr.Blocks() as demo:
    gr.Markdown(value="# Question Answering Demo")

    # Inputs: the question and the passage to answer it from.
    with gr.Accordion(variant='compact', label='Input Values'):
        with gr.Row(variant='compact'):
            query = gr.Textbox(
                value=queryDefault,
                label="Query",
                placeholder="Dummy Query",
                lines=2,
            )
            context = gr.Textbox(
                value=contextDefault,
                label="Context",
                placeholder="Dummy Context",
                lines=5,
                max_lines=6,
            )

    # Outputs: one column per model.
    with gr.Accordion(variant='compact', label='Q&A Model(s) Output'):
        with gr.Row(variant='compact'):
            with gr.Column(variant='compact'):
                _predictionM6 = gr.Textbox(label="question-answering-roberta-base-s: Answer Sentence")
                _predictionM5 = gr.Textbox(label="question-answering-roberta-base-s: Answer")
                # NOTE(review): "Cisco Q&A" appears only in this label; confirm it is intended.
                _predictionM7 = gr.Textbox(label="question-answering-roberta-base-s:Cisco Q&A Answer Span")
                _predictionM8 = gr.Textbox(label="question-answering-roberta-base-s: Answer Confidence")
            with gr.Column(variant='compact'):
                _predictionM10 = gr.Textbox(label="question-answering-generative-t5-v1-base-s-q-c: Sentence")
                _predictionM9 = gr.Textbox(label="question-answering-generative-t5-v1-base-s-q-c: Answer")
                _predictionM11 = gr.Textbox(label="question-answering-generative-t5-v1-base-s-q-c: Answer Span")

    with gr.Row():
        gen_btn = gr.Button("Generate Answers")
        # Order matches the tuple returned by predict().
        _prediction_outputs = [
            _predictionM5,
            _predictionM6,
            _predictionM7,
            _predictionM8,
            _predictionM9,
            _predictionM10,
            _predictionM11,
        ]
        gen_btn.click(fn=predict, inputs=[query, context], outputs=_prediction_outputs)

demo.launch(show_error=True)