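"""
Gradio demo comparing two extractive question-answering models (RoBERTa
encoders, V1 and V2) against a generative one (T5 seq2seq) on the same
query/context pair.
"""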
import re

import gradio as gr
import nltk
import torch
from nltk import sent_tokenize
from transformers import (
    pipeline,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
)

# Sentence-splitter data needed by nltk.sent_tokenize.
nltk.download('punkt', quiet=True)

# Prefer the first GPU when available; the string form is accepted by both
# tensor.to() and the transformers pipeline() API.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

def _generate(query, context, model, device):
    """Generate an answer with a fine-tuned seq2seq (T5) model."""
    # Note: the model and tokenizer are reloaded on every call; see the
    # caching sketch below for a way to avoid this.
    FT_MODEL = AutoModelForSeq2SeqLM.from_pretrained(model).to(device)
    FT_MODEL_TOKENIZER = AutoTokenizer.from_pretrained(model)

    # Prompt format the model was fine-tuned on.
    input_text = "question: " + query + "</s> question_context: " + context

    input_tokenized = FT_MODEL_TOKENIZER.encode(
        input_text,
        return_tensors='pt',
        truncation=True,
        padding='max_length',
        max_length=1024,
    ).to(device)

    summary_ids = FT_MODEL.generate(
        input_tokenized,
        max_length=30,
        min_length=3,
        length_penalty=1.0,
        num_beams=2,
        early_stopping=True,
    )
    output = [
        FT_MODEL_TOKENIZER.decode(ids, clean_up_tokenization_spaces=True, skip_special_tokens=True)
        for ids in summary_ids
    ]

    return str(output[0])
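
# _generate() reloads the seq2seq model from disk on every request, which is
# slow. A minimal caching sketch (_get_seq2seq is a hypothetical helper, not
# part of the original app): load each model once per process, then reuse it.
_SEQ2SEQ_CACHE = {}

def _get_seq2seq(model_name, device):
    # Load the (model, tokenizer) pair on first use only and memoize it.
    if model_name not in _SEQ2SEQ_CACHE:
        _SEQ2SEQ_CACHE[model_name] = (
            AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device),
            AutoTokenizer.from_pretrained(model_name),
        )
    return _SEQ2SEQ_CACHE[model_name]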
    
def _answer_sentence(answer, context):
    # Return the first context sentence containing the answer verbatim.
    matches = [_sent for _sent in sent_tokenize(context) if answer in _sent]
    if matches:
        return matches[0]
    return "Failed matching sentence (answer may be split in multiple sentences)"

def predict(query, context):
    # Drop non-ASCII characters, which can trip up span matching.
    context = context.encode("ascii", "ignore").decode()

    # Extractive model V1 (RoBERTa-base encoder).
    cust_model_name = "consciousAI/question-answering-roberta-base-s"
    cust_question_answerer = pipeline('question-answering', model=cust_model_name, tokenizer=cust_model_name, device=device)

    cust_output = cust_question_answerer(question=query, context=context)
    cust_answer = cust_output['answer']
    cust_answer_span = "[" + str(cust_output['start']) + "," + str(cust_output['end']) + "]"
    cust_confidence = cust_output['score']
    cust_answer_sentence = _answer_sentence(cust_answer, context)

    # Extractive model V2 (RoBERTa-base encoder, second fine-tune).
    cust_model_name_3 = "consciousAI/question-answering-roberta-base-s-v2"
    cust_question_answerer_3 = pipeline('question-answering', model=cust_model_name_3, tokenizer=cust_model_name_3, device=device)

    cust_output_3 = cust_question_answerer_3(question=query, context=context)
    cust_answer_3 = cust_output_3['answer']
    cust_answer_span_3 = "[" + str(cust_output_3['start']) + "," + str(cust_output_3['end']) + "]"
    cust_confidence_3 = cust_output_3['score']
    cust_answer_sentence_3 = _answer_sentence(cust_answer_3, context)

    # Generative model (T5 seq2seq).
    cust_answer_2 = _generate(query, context, model="consciousAI/question-answering-generative-t5-v1-base-s-q-c", device=device)
    cust_answer_sentence_2 = _answer_sentence(cust_answer_2, context)
    # Locate the generated answer in the *current* context. The original code
    # searched the default context with an unescaped pattern, which broke on
    # edited inputs and on regex metacharacters in the answer.
    _match_2 = re.search(re.escape(cust_answer_2), context)
    cust_answer_span_2 = "[" + str(_match_2.start()) + "," + str(_match_2.end()) + "]" if _match_2 else "Not found verbatim in context"

    return cust_answer, cust_answer_sentence, cust_answer_span, cust_confidence, cust_answer_2, cust_answer_sentence_2, cust_answer_span_2, cust_answer_sentence_3, cust_answer_3, cust_answer_span_3, cust_confidence_3
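
# predict() rebuilds both extractive pipelines on every click. A hypothetical
# refactor (not the original behavior) would construct them once at module
# scope and reuse them across requests, e.g.:
#
#   _QA_V1 = pipeline('question-answering',
#                     model="consciousAI/question-answering-roberta-base-s",
#                     device=device)
#   _QA_V2 = pipeline('question-answering',
#                     model="consciousAI/question-answering-roberta-base-s-v2",
#                     device=device)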
    
with gr.Blocks() as demo:
    gr.Markdown(value=(
        "# Question Answering Encoders vs Generative\n"
        "[Question Answering Leveraging Encoders V1](https://huggingface.co/anshoomehra/question-answering-roberta-base-s)\n\n"
        "[Question Answering Leveraging Encoders V2](https://huggingface.co/anshoomehra/question-answering-roberta-base-s-v2)\n\n"
        "[Generative Question Answering](https://huggingface.co/anshoomehra/question-answering-generative-t5-v1-base-s-q-c)"
    ))
    with gr.Accordion(variant='compact', label='Input Values'):
        with gr.Row(variant='compact'):
            queryDefault = "Which company alongside Amazon, Apple, Meta, and Microsoft is considered part of Big Five?"
            contextDefault = "Google LLC is an American multinational technology company focusing on search engine technology, online advertising, cloud computing, computer software, quantum computing, e-commerce, artificial intelligence, and consumer electronics. It has been referred to as 'the most powerful company in the world' and one of the world's most valuable brands due to its market dominance, data collection, and technological advantages in the area of artificial intelligence. Its parent company Alphabet is considered one of the Big Five American information technology companies, alongside Amazon, Apple, Meta, and Microsoft."
            query = gr.Textbox(queryDefault, label="Query", placeholder="Dummy Query", lines=2)
            context = gr.Textbox(contextDefault, label="Context", placeholder="Dummy Context", lines=5, max_lines=6)

    with gr.Accordion(variant='compact', label='Q&A Model(s) Output'):
        with gr.Row(variant='compact'):
            with gr.Column(variant='compact'):
                _predictionM6 = gr.Textbox(label="question-answering-roberta-base-s: Answer Sentence")
                _predictionM5 = gr.Textbox(label="question-answering-roberta-base-s: Answer")
                _predictionM7 = gr.Textbox(label="question-answering-roberta-base-s: Q&A Answer Span")
                _predictionM8 = gr.Textbox(label="question-answering-roberta-base-s: Answer Confidence")
            with gr.Column(variant='compact'):
                _predictionM12 = gr.Textbox(label="question-answering-roberta-base-s-v2: Answer Sentence")
                _predictionM13 = gr.Textbox(label="question-answering-roberta-base-s-v2: Answer")
                _predictionM14 = gr.Textbox(label="question-answering-roberta-base-s-v2: Q&A Answer Span")
                _predictionM15 = gr.Textbox(label="question-answering-roberta-base-s-v2: Answer Confidence")
            with gr.Column(variant='compact'):   
                _predictionM10 = gr.Textbox(label="question-answering-generative-t5-v1-base-s-q-c: Sentence")
                _predictionM9 = gr.Textbox(label="question-answering-generative-t5-v1-base-s-q-c: Answer")
                _predictionM11 = gr.Textbox(label="question-answering-generative-t5-v1-base-s-q-c: Answer Span")
                

    with gr.Row():       
        gen_btn = gr.Button("Generate Answers")
        gen_btn.click(fn=predict,
                      inputs=[query, context],
                      outputs=[_predictionM5, _predictionM6, _predictionM7, _predictionM8, _predictionM9, _predictionM10, _predictionM11, _predictionM12, _predictionM13, _predictionM14, _predictionM15]
                      )

demo.launch(show_error=True)