File size: 5,744 Bytes
45180a0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a796108
 
 
 
9061790
19bd5a9
abcac4c
a796108
45180a0
e4853cf
45180a0
a796108
 
45180a0
 
 
 
 
9061790
 
45180a0
9061790
 
 
 
a796108
9061790
 
 
 
 
 
45180a0
9061790
 
 
 
 
 
 
eb820e1
9061790
 
 
 
 
 
 
45180a0
9061790
 
 
 
eb820e1
45180a0
 
 
 
9061790
 
eb820e1
9061790
 
 
 
45180a0
9061790
 
 
 
 
 
 
 
 
 
 
 
45180a0
9061790
45180a0
9061790
 
 
45180a0
9061790
 
eb820e1
9061790
 
 
 
 
 
 
 
45180a0
9061790
 
 
 
eb820e1
45180a0
 
 
 
9061790
 
eb820e1
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
from prompts.arxiv_prompt import combine_prompt_template, _myscale_prompt
from callbacks.arxiv_callbacks import ChatDataSelfSearchCallBackHandler, \
    ChatDataSelfAskCallBackHandler, ChatDataSQLSearchCallBackHandler, \
    ChatDataSQLAskCallBackHandler
from chains.arxiv_chains import ArXivQAwithSourcesChain, ArXivStuffDocumentChain
from chains.arxiv_chains import VectorSQLRetrieveCustomOutputParser
from langchain_experimental.sql.vector_sql import VectorSQLDatabaseChain
from langchain_experimental.retrievers.vector_sql_database import VectorSQLDatabaseChainRetriever
from langchain.utilities.sql_database import SQLDatabase
from langchain.chains import LLMChain
from sqlalchemy import create_engine, MetaData
from langchain.prompts import PromptTemplate, ChatPromptTemplate, \
    SystemMessagePromptTemplate, HumanMessagePromptTemplate
from langchain.prompts.prompt import PromptTemplate
from langchain.chat_models import ChatOpenAI
from langchain import OpenAI
import re
import pandas as pd
from os import environ
import streamlit as st
import datetime
from helper import build_all, sel_map, display
# Expose the OpenAI endpoint stored in the app's secrets through the
# OPENAI_API_BASE environment variable.
environ['OPENAI_API_BASE'] = st.secrets['OPENAI_API_BASE']

st.set_page_config(page_title="ChatData")

st.header("ChatData")

# Streamlit re-executes this script on every interaction, so the objects
# returned by build_all() are cached in session state. The guard has to test
# the same key that is assigned (and read by the tabs below); the previous
# check looked for 'retriever', a key nothing in this file sets, which made
# build_all() run again on every rerun.
if 'sel_map_obj' not in st.session_state:
    st.session_state["sel_map_obj"] = build_all()

# `sel` is the key used to index both `sel_map` and the cached `sel_map_obj`.
sel = st.selectbox('Choose the knowledge base you want to ask with:',
                   options=['ArXiv Papers', 'Wikipedia'])
# Call the hint callable registered for the selected knowledge base.
sel_map[sel]['hint']()
tab_sql, tab_self_query = st.tabs(['Vector SQL', 'Self-Query Retrievers'])
with tab_sql:
    # Vector SQL tab: `Query` lists the documents returned by the SQL
    # retriever, `Ask` runs the SQL chain and shows its answer with references.
    sel_map[sel]['hint_sql']()
    st.text_input("Ask a question:", key='query_sql')
    query_col, ask_col, _spacer = st.columns([1, 1, 7])
    query_col.button("Query", key='search_sql')
    ask_col.button("Ask", key='ask_sql')
    log_slot = st.empty()

    # "Query": retrieve documents only and show them as a table.
    if st.session_state.search_sql:
        log_slot = st.empty()
        print(st.session_state.query_sql)
        with log_slot.expander('Query Log', expanded=True):
            handler = ChatDataSQLSearchCallBackHandler()
            try:
                sql_retriever = st.session_state.sel_map_obj[sel]["sql_retriever"]
                hits = sql_retriever.get_relevant_documents(
                    st.session_state.query_sql, callbacks=[handler])
                # Retrieval finished: push the handler's progress bar to 100%.
                handler.progress_bar.progress(value=1.0, text="Done!")
                # One row per document: its metadata plus the text as 'abstract'.
                rows = [dict(hit.metadata, abstract=hit.page_content)
                        for hit in hits]
                display(pd.DataFrame(rows))
            except Exception as err:
                # Tell the user, then re-raise so the error still surfaces.
                st.write('Oops 😡 Something bad happened...')
                raise err

    # "Ask": run the chain, print its answer and tabulate the cited sources.
    if st.session_state.ask_sql:
        log_slot = st.empty()
        print(st.session_state.query_sql)
        with log_slot.expander('Chat Log', expanded=True):
            handler = ChatDataSQLAskCallBackHandler()
            try:
                sql_chain = st.session_state.sel_map_obj[sel]["sql_chain"]
                result = sql_chain(
                    st.session_state.query_sql, callbacks=[handler])
                handler.progress_bar.progress(value=1.0, text="Done!")
                st.markdown(
                    f"### Answer from LLM\n{result['answer']}\n### References")
                rows = [dict(src.metadata, abstract=src.page_content)
                        for src in result['sources']]
                frame = pd.DataFrame(rows)
                shown_cols = ['ref_id'] + sel_map[sel]["must_have_cols"]
                display(frame, shown_cols, index='ref_id')
            except Exception as err:
                st.write('Oops 😡 Something bad happened...')
                raise err


with tab_self_query:
    # Self-query tab: shows the selected knowledge base's metadata columns and
    # lets the user either retrieve documents (`Query`) or get an LLM answer
    # with references (`Ask`).
    # `icon` has to be a single valid emoji; the previous literal was a
    # mis-decoded (mojibake) light bulb, which Streamlit rejects.
    st.info("You can retrieve papers with button `Query` or ask questions based on retrieved papers with button `Ask`.", icon='💡')
    st.dataframe(st.session_state.sel_map_obj[sel]["metadata_columns"])
    st.text_input("Ask a question:", key='query_self')
    cols = st.columns([1, 1, 7])
    cols[0].button("Query", key='search_self')
    cols[1].button("Ask", key='ask_self')
    plc_hldr = st.empty()

    # "Query": run the retriever only and show the matching documents.
    if st.session_state.search_self:
        plc_hldr = st.empty()
        print(st.session_state.query_self)
        with plc_hldr.expander('Query Log', expanded=True):
            callback = ChatDataSelfSearchCallBackHandler()
            try:
                docs = st.session_state.sel_map_obj[sel]["retriever"].get_relevant_documents(
                    st.session_state.query_self, callbacks=[callback])
                print(docs)
                # Retrieval finished: push the handler's progress bar to 100%.
                callback.progress_bar.progress(value=1.0, text="Done!")
                # One row per document: its metadata plus the text as 'abstract'.
                docs = pd.DataFrame(
                    [{**d.metadata, 'abstract': d.page_content} for d in docs])
                # NOTE(review): second argument is presumably the list of
                # columns `display` should show - confirm against helper.display.
                display(docs, sel_map[sel]["must_have_cols"])
            except Exception:
                # Tell the user, then re-raise so the error still surfaces.
                st.write('Oops 😡 Something bad happened...')
                raise

    # "Ask": run the chain, print its answer and tabulate the cited sources.
    if st.session_state.ask_self:
        plc_hldr = st.empty()
        print(st.session_state.query_self)
        with plc_hldr.expander('Chat Log', expanded=True):
            callback = ChatDataSelfAskCallBackHandler()
            try:
                ret = st.session_state.sel_map_obj[sel]["chain"](
                    st.session_state.query_self, callbacks=[callback])
                callback.progress_bar.progress(value=1.0, text="Done!")
                st.markdown(
                    f"### Answer from LLM\n{ret['answer']}\n### References")
                docs = ret['sources']
                docs = pd.DataFrame(
                    [{**d.metadata, 'abstract': d.page_content} for d in docs])
                display(
                    docs, ['ref_id'] + sel_map[sel]["must_have_cols"], index='ref_id')
            except Exception:
                st.write('Oops 😡 Something bad happened...')
                raise