Fangrui Liu commited on
Commit
d5a4cb4
β€’
1 Parent(s): 526644e

fix callback

Browse files
Files changed (1) hide show
  1. callbacks/arxiv_callbacks.py +10 -3
callbacks/arxiv_callbacks.py CHANGED
@@ -2,6 +2,7 @@ import streamlit as st
2
  from typing import Dict, Any
3
  from sql_formatter.core import format_sql
4
  from langchain.callbacks.streamlit.streamlit_callback_handler import StreamlitCallbackHandler
 
5
 
6
  class ChatDataSelfSearchCallBackHandler(StreamlitCallbackHandler):
7
  def __init__(self) -> None:
@@ -62,8 +63,14 @@ class ChatDataSQLSearchCallBackHandler(StreamlitCallbackHandler):
62
  def on_llm_start(self, serialized, prompts, **kwargs) -> None:
63
  pass
64
 
65
- def on_text(self, text: str, **kwargs) -> None:
66
- if text.startswith('SELECT'):
 
 
 
 
 
 
67
  st.write('We generated Vector SQL for you:')
68
  st.markdown(f'''```sql\n{format_sql(text, max_len=80)}\n```''')
69
  print(f"Vector SQL: {text}")
@@ -83,4 +90,4 @@ class ChatDataSQLAskCallBackHandler(ChatDataSQLSearchCallBackHandler):
83
  self.progress_bar = st.progress(value=0.0, text='Writing SQL...')
84
  self.status_bar = st.empty()
85
  self.prog_value = 0
86
- self.prog_interval = 0.1
 
2
  from typing import Dict, Any
3
  from sql_formatter.core import format_sql
4
  from langchain.callbacks.streamlit.streamlit_callback_handler import StreamlitCallbackHandler
5
+ from langchain.schema.output import LLMResult
6
 
7
  class ChatDataSelfSearchCallBackHandler(StreamlitCallbackHandler):
8
  def __init__(self) -> None:
 
63
  def on_llm_start(self, serialized, prompts, **kwargs) -> None:
64
  pass
65
 
66
+ def on_llm_end(
67
+ self,
68
+ response: LLMResult,
69
+ *args,
70
+ **kwargs,
71
+ ):
72
+ text = response.generations[0][0].text
73
+ if text.replace(' ', '').upper().startswith('SELECT'):
74
  st.write('We generated Vector SQL for you:')
75
  st.markdown(f'''```sql\n{format_sql(text, max_len=80)}\n```''')
76
  print(f"Vector SQL: {text}")
 
90
  self.progress_bar = st.progress(value=0.0, text='Writing SQL...')
91
  self.status_bar = st.empty()
92
  self.prog_value = 0
93
+ self.prog_interval = 0.1