mpsk committed on
Commit
04f0bde
•
1 Parent(s): 5ee41cc

add parse and private knowledge base

Browse files
Files changed (5) hide show
  1. app.py +2 -2
  2. chat.py +63 -9
  3. helper.py +7 -4
  4. lib/private_kb.py +138 -0
  5. lib/sessions.py +10 -2
app.py CHANGED
@@ -28,8 +28,8 @@ st.markdown(
28
  )
29
  st.header("ChatData")
30
 
31
- if 'sel_map_obj' not in st.session_state:
32
- st.session_state["sel_map_obj"] = build_all()
33
  st.session_state["tools"] = build_tools()
34
 
35
  if login():
 
28
  )
29
  st.header("ChatData")
30
 
31
+ if 'sel_map_obj' not in st.session_state or 'embeddings' not in st.session_state:
32
+ st.session_state["sel_map_obj"], st.session_state["embeddings"] = build_all()
33
  st.session_state["tools"] = build_tools()
34
 
35
  if login():
chat.py CHANGED
@@ -4,6 +4,7 @@ from time import sleep
4
  import datetime
5
  import streamlit as st
6
  from lib.sessions import SessionManager
 
7
  from langchain.schema import HumanMessage, FunctionMessage
8
  from callbacks.arxiv_callbacks import ChatDataAgentCallBackHandler
9
  from langchain.callbacks.streamlit.streamlit_callback_handler import StreamlitCallbackHandler
@@ -15,6 +16,7 @@ from helper import (
15
  MYSCALE_PORT,
16
  MYSCALE_USER,
17
  DEFAULT_SYSTEM_PROMPT,
 
18
  )
19
  from login import back_to_main
20
 
@@ -49,6 +51,10 @@ def back_to_main():
49
  del st.session_state.user_name
50
  if "jump_query_ask" in st.session_state:
51
  del st.session_state.jump_query_ask
 
 
 
 
52
 
53
 
54
  def on_session_change_submit():
@@ -87,6 +93,7 @@ def on_session_change_submit():
87
 
88
  def build_session_manager():
89
  return SessionManager(
 
90
  host=MYSCALE_HOST,
91
  port=MYSCALE_PORT,
92
  username=MYSCALE_USER,
@@ -130,7 +137,23 @@ def refresh_agent():
130
  if "sel_sess" not in st.session_state
131
  else st.session_state.sel_sess["system_prompt"],
132
  )
133
- st.session_state["session_manager"] = build_session_manager()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
134
 
135
 
136
  def chat_page():
@@ -139,7 +162,17 @@ def chat_page():
139
  "session_id": "default",
140
  "system_prompt": DEFAULT_SYSTEM_PROMPT,
141
  }
142
- st.session_state["session_manager"] = build_session_manager()
 
 
 
 
 
 
 
 
 
 
143
  with st.sidebar:
144
  with st.expander("Session Management"):
145
  if "current_sessions" not in st.session_state:
@@ -179,13 +212,34 @@ def chat_page():
179
  with st.expander("Tool Settings", expanded=True):
180
  st.info("Here you can select your tools.", icon="🔧")
181
  st.info("We provides you several knowledge base tools for you. We are building more tools!", icon="👷‍♂️")
182
- st.multiselect(
183
- "Knowledge Base",
184
- st.session_state.tools.keys(),
185
- default=["Wikipedia + Self Querying"],
186
- key="selected_tools",
187
- on_change=refresh_agent,
188
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
189
  st.button("Clear Chat History", on_click=clear_history)
190
  st.button("Logout", on_click=back_to_main)
191
  if 'agent' not in st.session_state:
 
4
  import datetime
5
  import streamlit as st
6
  from lib.sessions import SessionManager
7
+ from lib.private_kb import PrivateKnowledgeBase
8
  from langchain.schema import HumanMessage, FunctionMessage
9
  from callbacks.arxiv_callbacks import ChatDataAgentCallBackHandler
10
  from langchain.callbacks.streamlit.streamlit_callback_handler import StreamlitCallbackHandler
 
16
  MYSCALE_PORT,
17
  MYSCALE_USER,
18
  DEFAULT_SYSTEM_PROMPT,
19
+ UNSTRUCTURED_API,
20
  )
21
  from login import back_to_main
22
 
 
51
  del st.session_state.user_name
52
  if "jump_query_ask" in st.session_state:
53
  del st.session_state.jump_query_ask
54
+ if "sel_sess" in st.session_state:
55
+ del st.session_state.sel_sess
56
+ if "current_sessions" in st.session_state:
57
+ del st.session_state.current_sessions
58
 
59
 
60
  def on_session_change_submit():
 
93
 
94
  def build_session_manager():
95
  return SessionManager(
96
+ st.session_state,
97
  host=MYSCALE_HOST,
98
  port=MYSCALE_PORT,
99
  username=MYSCALE_USER,
 
137
  if "sel_sess" not in st.session_state
138
  else st.session_state.sel_sess["system_prompt"],
139
  )
140
+
141
+ def add_file():
142
+ if 'uploaded_files' not in st.session_state or len(st.session_state.uploaded_files) == 0:
143
+ st.session_state.tool_status.error("Please upload files!", icon="⚠️")
144
+ sleep(2)
145
+ return
146
+ try:
147
+ st.session_state.tool_status.info("Uploading...")
148
+ print([(f.name, f.type) for f in st.session_state.uploaded_files])
149
+ st.session_state.private_kb.add_by_file(st.session_state.user_name,
150
+ st.session_state.uploaded_files)
151
+ except ValueError as e:
152
+ st.session_state.tool_status.error("Failed to upload! " + str(e))
153
+ sleep(2)
154
+
155
+ def clear_files():
156
+ st.session_state.private_kb.clear(st.session_state.user_name)
157
 
158
 
159
  def chat_page():
 
162
  "session_id": "default",
163
  "system_prompt": DEFAULT_SYSTEM_PROMPT,
164
  }
165
+ if "private_kb" not in st.session_state:
166
+ st.session_state["private_kb"] = PrivateKnowledgeBase(
167
+ host=MYSCALE_HOST,
168
+ port=MYSCALE_PORT,
169
+ username=MYSCALE_USER,
170
+ password=MYSCALE_PASSWORD,
171
+ embedding=st.session_state.embeddings['Wikipedia'],
172
+ parser_api_key=UNSTRUCTURED_API,
173
+ )
174
+ if "session_manager" not in st.session_state:
175
+ st.session_state["session_manager"] = build_session_manager()
176
  with st.sidebar:
177
  with st.expander("Session Management"):
178
  if "current_sessions" not in st.session_state:
 
212
  with st.expander("Tool Settings", expanded=True):
213
  st.info("Here you can select your tools.", icon="🔧")
214
  st.info("We provides you several knowledge base tools for you. We are building more tools!", icon="👷‍♂️")
215
+ st.session_state["tool_status"] = st.empty()
216
+ tab_kb, tab_file, tab_build = st.tabs(["Knowledge Bases", "File Upload", "KB Builder"])
217
+ with tab_kb:
218
+ st.multiselect(
219
+ "Select a Knowledge Base Tool",
220
+ st.session_state.tools.keys(),
221
+ default=["Wikipedia + Self Querying"],
222
+ key="selected_tools",
223
+ on_change=refresh_agent,
224
+ )
225
+ with tab_file:
226
+ st.file_uploader("Upload files", key="uploaded_files", accept_multiple_files=True)
227
+ st.markdown("### Uploaded Files")
228
+ st.dataframe(st.session_state.private_kb.list_files(st.session_state.user_name))
229
+ col_1, col_2 = st.columns(2)
230
+ with col_1:
231
+ st.button("Add Files", on_click=add_file)
232
+ with col_2:
233
+ st.button("Clear Files", on_click=clear_files)
234
+ # with tab_build:
235
+ # st.text_input("Give this knowledge base a description:")
236
+ # col_3, col_4 = st.columns(2)
237
+ # with col_3:
238
+ # st.button("Build Your KB!")
239
+ # with col_4:
240
+ # st.button("Delete Your KB")
241
+
242
+
243
  st.button("Clear Chat History", on_click=clear_history)
244
  st.button("Logout", on_click=back_to_main)
245
  if 'agent' not in st.session_state:
helper.py CHANGED
@@ -2,7 +2,7 @@
2
  import json
3
  import time
4
  import hashlib
5
- from typing import Dict, Any, List
6
  import re
7
  import pandas as pd
8
  from os import environ
@@ -67,6 +67,7 @@ MYSCALE_USER = st.secrets['MYSCALE_USER']
67
  MYSCALE_PASSWORD = st.secrets['MYSCALE_PASSWORD']
68
  MYSCALE_HOST = st.secrets['MYSCALE_HOST']
69
  MYSCALE_PORT = st.secrets['MYSCALE_PORT']
 
70
 
71
  COMBINE_PROMPT = ChatPromptTemplate.from_strings(
72
  string_messages=[(SystemMessagePromptTemplate, combine_prompt_template),
@@ -348,17 +349,19 @@ def build_qa_chain(_sel: str, retriever: BaseRetriever, name: str="Self-query")
348
  return chain
349
 
350
  @st.cache_resource
351
- def build_all() -> Dict[str, Any]:
352
  """build all resources
353
 
354
  :return: sel_map_obj
355
  :rtype: Dict[str, Any]
356
  """
357
  sel_map_obj = {}
 
358
  for k in sel_map:
359
- st.session_state[f'emb_model_{k}'] = build_embedding_model(k)
 
360
  sel_map_obj[k] = build_chains_retrievers(k)
361
- return sel_map_obj
362
 
363
  def create_message_model(table_name, DynamicBase): # type: ignore
364
  """
 
2
  import json
3
  import time
4
  import hashlib
5
+ from typing import Dict, Any, List, Tuple
6
  import re
7
  import pandas as pd
8
  from os import environ
 
67
  MYSCALE_PASSWORD = st.secrets['MYSCALE_PASSWORD']
68
  MYSCALE_HOST = st.secrets['MYSCALE_HOST']
69
  MYSCALE_PORT = st.secrets['MYSCALE_PORT']
70
+ UNSTRUCTURED_API = st.secrets['UNSTRUCTURED_API']
71
 
72
  COMBINE_PROMPT = ChatPromptTemplate.from_strings(
73
  string_messages=[(SystemMessagePromptTemplate, combine_prompt_template),
 
349
  return chain
350
 
351
  @st.cache_resource
352
+ def build_all() -> Tuple[Dict[str, Any], Dict[str, Any]]:
353
  """build all resources
354
 
355
  :return: sel_map_obj
356
  :rtype: Dict[str, Any]
357
  """
358
  sel_map_obj = {}
359
+ embeddings = {}
360
  for k in sel_map:
361
+ embeddings[k] = build_embedding_model(k)
362
+ st.session_state[f'emb_model_{k}'] = embeddings[k]
363
  sel_map_obj[k] = build_chains_retrievers(k)
364
+ return sel_map_obj, embeddings
365
 
366
  def create_message_model(table_name, DynamicBase): # type: ignore
367
  """
lib/private_kb.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import hashlib
3
+ import requests
4
+ from typing import List
5
+ from datetime import datetime
6
+ from langchain.schema.embeddings import Embeddings
7
+ from streamlit.runtime.uploaded_file_manager import UploadedFile
8
+ from clickhouse_connect import get_client
9
+ from multiprocessing.pool import ThreadPool
10
+ from langchain.vectorstores.myscale import MyScaleWithoutJSON, MyScaleSettings
11
+
12
+ parser_url = "https://api.unstructured.io/general/v0/general"
13
+
14
+
15
+ def parse_files(api_key, user_id, files: List[UploadedFile], collection="default"):
16
+ def parse_file(file: UploadedFile):
17
+ headers = {
18
+ "accept": "application/json",
19
+ "unstructured-api-key": api_key,
20
+ }
21
+ data = {"strategy": "auto", "ocr_languages": ["eng"]}
22
+ file_hash = hashlib.sha256(file.read()).hexdigest()
23
+ file_data = {"files": (file.name, file.getvalue(), file.type)}
24
+ response = requests.post(
25
+ parser_url, headers=headers, data=data, files=file_data
26
+ )
27
+ json_response = response.json()
28
+ if response.status_code != 200:
29
+ raise ValueError(str(json_response))
30
+ texts = [
31
+ {
32
+ "text": t["text"],
33
+ "file_name": t["metadata"]["filename"],
34
+ "entity_id": hashlib.sha256((file_hash + t["text"]).encode()).hexdigest(),
35
+ "user_id": user_id,
36
+ "collection_id": collection,
37
+ "created_by": datetime.now(),
38
+ }
39
+ for t in json_response
40
+ if t["type"] == "NarrativeText" and len(t["text"].split(" ")) > 10
41
+ ]
42
+ return texts
43
+
44
+ with ThreadPool(8) as p:
45
+ rows = []
46
+ for r in map(parse_file, files):
47
+ rows.extend(r)
48
+ return rows
49
+
50
+
51
+ def extract_embedding(embeddings: Embeddings, texts):
52
+ if len(texts) > 0:
53
+ embs = embeddings.embed_documents([t["text"] for _, t in enumerate(texts)])
54
+ for i, _ in enumerate(texts):
55
+ texts[i]["vector"] = embs[i]
56
+ return texts
57
+ raise ValueError("No texts extracted!")
58
+
59
+
60
+ class PrivateKnowledgeBase:
61
+ def __init__(
62
+ self,
63
+ host,
64
+ port,
65
+ username,
66
+ password,
67
+ embedding: Embeddings,
68
+ parser_api_key,
69
+ db="chat",
70
+ kb_table="private_kb",
71
+ ) -> None:
72
+ super().__init__()
73
+ schema_ = f"""
74
+ CREATE TABLE IF NOT EXISTS {db}.{kb_table}(
75
+ entity_id String,
76
+ file_name String,
77
+ text String,
78
+ user_id String,
79
+ collection_id String,
80
+ created_by DateTime,
81
+ vector Array(Float32),
82
+ CONSTRAINT cons_vec_len CHECK length(vector) = 768,
83
+ VECTOR INDEX vidx vector TYPE MSTG('metric_type=Cosine')
84
+ ) ENGINE = ReplacingMergeTree ORDER BY entity_id
85
+ """
86
+ config = MyScaleSettings(
87
+ host=host,
88
+ port=port,
89
+ username=username,
90
+ password=password,
91
+ database=db,
92
+ table=kb_table,
93
+ )
94
+ client = get_client(
95
+ host=config.host,
96
+ port=config.port,
97
+ username=config.username,
98
+ password=config.password,
99
+ )
100
+ client.command("SET allow_experimental_object_type=1")
101
+ client.command(schema_)
102
+ self.parser_api_key = parser_api_key
103
+ self.vstore = MyScaleWithoutJSON(
104
+ embedding=embedding,
105
+ config=config,
106
+ must_have_cols=["file_name", "text", "create_by"],
107
+ )
108
+ self.retriever = self.vstore.as_retriever()
109
+
110
+ def list_files(self, user_id):
111
+ query = f"""
112
+ SELECT DISTINCT file_name FROM {self.vstore.config.database}.{self.vstore.config.table}
113
+ WHERE user_id = '{user_id}'
114
+ """
115
+ return [r for r in self.vstore.client.query(query).named_results()]
116
+
117
+ def add_by_file(
118
+ self, user_id, files: List[UploadedFile], collection="default", **kwargs
119
+ ):
120
+ data = parse_files(self.parser_api_key, user_id, files, collection=collection)
121
+ data = extract_embedding(self.vstore.embeddings, data)
122
+ self.vstore.client.insert_df(
123
+ self.vstore.config.table,
124
+ pd.DataFrame(data),
125
+ database=self.vstore.config.database,
126
+ )
127
+
128
+ def clear(self, user_id):
129
+ self.vstore.client.command(
130
+ f"DELETE FROM {self.vstore.config.database}.{self.vstore.config.table} "
131
+ f"WHERE user_id='{user_id}'"
132
+ )
133
+
134
+ def _get_relevant_documents(self, query, *args, **kwargs):
135
+ return self.retriever._get_relevant_documents(query, *args, **kwargs)
136
+
137
+ async def _aget_relevant_documents(self, *args, **kwargs):
138
+ return self.retriever._aget_relevant_documents(*args, **kwargs)
lib/sessions.py CHANGED
@@ -3,10 +3,12 @@ try:
3
  from sqlalchemy.orm import declarative_base
4
  except ImportError:
5
  from sqlalchemy.ext.declarative import declarative_base
 
6
  from datetime import datetime
7
  from sqlalchemy import Column, Text, orm, create_engine
8
  from clickhouse_sqlalchemy import types, engines
9
  from .schemas import create_message_model, create_session_table
 
10
 
11
  def get_sessions(engine, model_class, user_id):
12
  with orm.sessionmaker(engine)() as session:
@@ -20,7 +22,8 @@ def get_sessions(engine, model_class, user_id):
20
  return json.loads(result)
21
 
22
  class SessionManager:
23
- def __init__(self, host, port, username, password, db='chat', sess_table='sessions', msg_table='chat_memory') -> None:
 
24
  conn_str = f'clickhouse://{username}:{password}@{host}:{port}/{db}?protocol=https'
25
  self.engine = create_engine(conn_str, echo=False)
26
  self.sess_model_class = create_session_table(sess_table, declarative_base())
@@ -28,6 +31,7 @@ class SessionManager:
28
  self.msg_model_class = create_message_model(msg_table, declarative_base())
29
  self.msg_model_class.metadata.create_all(self.engine)
30
  self.Session = orm.sessionmaker(self.engine)
 
31
 
32
  def list_sessions(self, user_id):
33
  with self.Session() as session:
@@ -63,6 +67,10 @@ class SessionManager:
63
  def remove_session(self, session_id):
64
  with self.Session() as session:
65
  session.query(self.sess_model_class).where(self.sess_model_class.session_id==session_id).delete()
66
- session.query(self.msg_model_class).where(self.msg_model_class.session_id==session_id).delete()
 
 
 
 
67
 
68
 
 
3
  from sqlalchemy.orm import declarative_base
4
  except ImportError:
5
  from sqlalchemy.ext.declarative import declarative_base
6
+ from langchain.schema import BaseChatMessageHistory
7
  from datetime import datetime
8
  from sqlalchemy import Column, Text, orm, create_engine
9
  from clickhouse_sqlalchemy import types, engines
10
  from .schemas import create_message_model, create_session_table
11
+ from .private_kb import PrivateKnowledgeBase
12
 
13
  def get_sessions(engine, model_class, user_id):
14
  with orm.sessionmaker(engine)() as session:
 
22
  return json.loads(result)
23
 
24
  class SessionManager:
25
+ def __init__(self, session_state, host, port, username, password,
26
+ db='chat', sess_table='sessions', msg_table='chat_memory') -> None:
27
  conn_str = f'clickhouse://{username}:{password}@{host}:{port}/{db}?protocol=https'
28
  self.engine = create_engine(conn_str, echo=False)
29
  self.sess_model_class = create_session_table(sess_table, declarative_base())
 
31
  self.msg_model_class = create_message_model(msg_table, declarative_base())
32
  self.msg_model_class.metadata.create_all(self.engine)
33
  self.Session = orm.sessionmaker(self.engine)
34
+ self.session_state = session_state
35
 
36
  def list_sessions(self, user_id):
37
  with self.Session() as session:
 
67
  def remove_session(self, session_id):
68
  with self.Session() as session:
69
  session.query(self.sess_model_class).where(self.sess_model_class.session_id==session_id).delete()
70
+ # session.query(self.msg_model_class).where(self.msg_model_class.session_id==session_id).delete()
71
+ if "agent" in self.session_state:
72
+ self.session_state.agent.memory.chat_memory.clear()
73
+ if "file_analyzer" in self.session_state:
74
+ self.session_state.file_analyzer.clear_files()
75
 
76