# ChatData / callbacks / arxiv_callbacks.py
# Last commit 9061790 by Fangrui Liu: "Add text 2 sql query & ask"
# (source page metadata: raw | history blame | No virus | 3.53 kB)
import streamlit as st
from typing import Dict, Any
from sql_formatter.core import format_sql
from langchain.callbacks.streamlit.streamlit_callback_handler import StreamlitCallbackHandler
class ChatDataSelfSearchCallBackHandler(StreamlitCallbackHandler):
    """Streamlit progress reporter for the self-query search flow.

    A progress bar is created on construction. ``on_text`` events move it to
    20% ("Asking LLM..."), and ``on_chain_end`` moves it to 60% ("Searching in
    DB...") and renders the chain's ``'text'`` output under a
    "Generated Filter" heading. LLM-start and chain-start events are ignored.

    NOTE(review): ``StreamlitCallbackHandler.__init__`` is deliberately not
    called; hooks not overridden here fall back to the base class — confirm
    those do not depend on state the base initialiser would have set.
    """

    def __init__(self) -> None:
        bar = st.progress(value=0.0, text="Working...")
        self.progress_bar = bar
        # Token accumulator; no method in this class reads or writes it again.
        self.tokens_stream = ""

    def on_llm_start(self, serialized, prompts, **kwargs) -> None:
        """Ignore LLM start events."""
        return None

    def on_chain_start(self, serialized, inputs, **kwargs) -> None:
        """Ignore chain start events."""
        return None

    def on_text(self, text: str, **kwargs) -> None:
        """Show that the LLM is being queried (progress 20%)."""
        self.progress_bar.progress(value=0.2, text="Asking LLM...")

    def on_chain_end(self, outputs, **kwargs) -> None:
        """Advance to 60% and display the generated filter.

        Args:
            outputs: chain output mapping; its ``'text'`` entry is rendered.
        """
        self.progress_bar.progress(value=0.6, text='Searching in DB...')
        st.markdown('### Generated Filter')
        # NOTE(review): model output is rendered with unsafe_allow_html=True;
        # confirm raw HTML coming back from the LLM is acceptable here.
        generated = outputs['text']
        st.write(generated, unsafe_allow_html=True)
class ChatDataSelfAskCallBackHandler(StreamlitCallbackHandler):
    """Streamlit progress reporter for the retrieval-QA-with-sources flow.

    Chains listed in ``prog_map`` jump the progress bar to a fixed milestone;
    every ``LLMChain`` start nudges it forward by 0.1 from the current value.
    LLM-start, text and chain-end events are ignored.
    """

    def __init__(self) -> None:
        self.progress_bar = st.progress(value=0.0, text='Searching DB...')
        self.status_bar = st.empty()
        self.prog_value = 0.0
        # Milestone progress for each known chain, keyed by its dotted id.
        self.prog_map = {
            'langchain.chains.qa_with_sources.retrieval.RetrievalQAWithSourcesChain': 0.2,
            'langchain.chains.combine_documents.map_reduce.MapReduceDocumentsChain': 0.4,
            'langchain.chains.combine_documents.stuff.StuffDocumentsChain': 0.8
        }

    def on_llm_start(self, serialized, prompts, **kwargs) -> None:
        pass

    def on_text(self, text: str, **kwargs) -> None:
        pass

    def on_chain_start(self, serialized, inputs, **kwargs) -> None:
        """Advance the progress bar when a chain starts.

        Args:
            serialized: serialized chain description; its ``'id'`` parts are
                joined with dots to identify the chain.
            inputs: chain inputs (unused).
        """
        cid = '.'.join(serialized['id'])
        if cid == 'langchain.chains.llm.LLMChain':
            next_value = self.prog_value + 0.1
        else:
            # Chains missing from prog_map keep the current value instead of
            # failing with KeyError.
            next_value = self.prog_map.get(cid, self.prog_value)
        # st.progress only accepts floats in [0.0, 1.0], and nothing bounds how
        # many LLMChain starts occur, so cap the running value.
        self.prog_value = min(next_value, 1.0)
        self.progress_bar.progress(value=self.prog_value, text=f'Running Chain `{cid}`...')

    def on_chain_end(self, outputs, **kwargs) -> None:
        pass
class ChatDataSQLSearchCallBackHandler(StreamlitCallbackHandler):
    """Streamlit progress reporter for the text-to-vector-SQL search flow.

    Every chain start, and every ``on_text`` payload beginning with
    ``SELECT``, advances the progress bar by ``prog_interval``. Generated SQL
    is pretty-printed into the page and echoed to stdout.
    """

    def __init__(self, prog_interval: float = 0.2) -> None:
        """Create the progress bar and a status placeholder.

        Args:
            prog_interval: amount the bar advances per step (default 0.2).
        """
        self.progress_bar = st.progress(value=0.0, text='Writing SQL...')
        self.status_bar = st.empty()
        self.prog_value = 0
        self.prog_interval = prog_interval

    def _step_progress(self, text: str) -> None:
        """Advance the bar by ``prog_interval`` and show ``text``.

        The value is clamped to [0.0, 1.0]: ``st.progress`` rejects floats
        outside that range, and nothing bounds the number of steps taken.
        """
        self.prog_value = min(max(self.prog_value + self.prog_interval, 0.0), 1.0)
        self.progress_bar.progress(value=self.prog_value, text=text)

    def on_llm_start(self, serialized, prompts, **kwargs) -> None:
        pass

    def on_text(self, text: str, **kwargs) -> None:
        """Render generated vector SQL (text starting with ``SELECT``)."""
        if text.startswith('SELECT'):
            st.write('We generated Vector SQL for you:')
            st.markdown(f'''```sql\n{format_sql(text, max_len=80)}\n```''')
            print(f"Vector SQL: {text}")
            self._step_progress("Searching in DB...")

    def on_chain_start(self, serialized, inputs, **kwargs) -> None:
        """Advance the bar and label it with the starting chain's dotted id."""
        cid = '.'.join(serialized['id'])
        self._step_progress(f'Running Chain `{cid}`...')

    def on_chain_end(self, outputs, **kwargs) -> None:
        pass
class ChatDataSQLAskCallBackHandler(ChatDataSQLSearchCallBackHandler):
    """SQL search progress reporter used by the SQL ask flow.

    Behaves exactly like its parent, except each step advances the progress
    bar by 0.1 instead of the parent's 0.2.
    """

    def __init__(self) -> None:
        # Reuse the parent's setup (progress bar, status placeholder, counter)
        # rather than duplicating it, then switch to the finer step size.
        super().__init__()
        self.prog_interval = 0.1