mpsk commited on
Commit
f0781cb
1 Parent(s): e2fd469

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -11
app.py CHANGED
@@ -3,7 +3,7 @@ import torch
3
  import esm
4
  import requests
5
  import matplotlib.pyplot as plt
6
- from myscaledb import Client
7
  import random
8
  from collections import Counter
9
  from tqdm import tqdm
@@ -47,8 +47,6 @@ def init_db():
47
  host=r['host'], port=r['port'], user=st.secrets["USER"], password=st.secrets["PASSWD"],
48
  interface=r['http_pre'],
49
  )
50
- # We can check if the connection is alive
51
- assert client.is_alive()
52
  meta_field = {}
53
  return meta_field, Client
54
 
@@ -126,10 +124,8 @@ def esm_search(model, sequnce, batch_converter,top_k=5):
126
 
127
  token_list = token_representations.tolist()[0][0][0]
128
 
129
- client = Client(
130
- url=st.secrets["DB_URL"], user=st.secrets["USER"], password=st.secrets["PASSWD"])
131
-
132
- result = client.fetch("SELECT seq, distance(representations, " + str(token_list) + ')'+ "as dist FROM default.esm_protein_indexer_768 ORDER BY dist LIMIT 500")
133
 
134
  result_temp_seq = []
135
 
@@ -167,10 +163,10 @@ def KNN_search(sequence):
167
  token_representations = results["representations"][33]
168
  token_list = token_representations.tolist()[0][0]
169
  print(token_list)
170
- client = Client(
171
- url=st.secrets["DB_URL"], user=st.secrets["USER"], password=st.secrets["PASSWD"])
172
 
173
- result = client.fetch("SELECT activity, distance('topK=10')(representations, " + str(token_list) + ')'+ "as dist FROM default.esm_protein_indexer")
 
 
174
  result_temp_activity = []
175
  for i in result:
176
  # print(result_temp_seq)
@@ -286,7 +282,7 @@ def init_random_query():
286
 
287
 
288
  with st.spinner("Connecting DB..."):
289
- st.session_state.meta, client = init_db()
290
 
291
  with st.spinner("Loading Models..."):
292
  # Initialize SAGE model
 
3
  import esm
4
  import requests
5
  import matplotlib.pyplot as plt
6
+ from clickhouse_connect import get_client
7
  import random
8
  from collections import Counter
9
  from tqdm import tqdm
 
47
  host=r['host'], port=r['port'], user=st.secrets["USER"], password=st.secrets["PASSWD"],
48
  interface=r['http_pre'],
49
  )
 
 
50
  meta_field = {}
51
  return meta_field, Client
52
 
 
124
 
125
  token_list = token_representations.tolist()[0][0][0]
126
 
127
+ result = st.session_state.client.query("SELECT seq, distance(representations, " + str(token_list) + ')'+ "as dist FROM default.esm_protein_indexer_768 ORDER BY dist LIMIT 500")
128
+ result = [r for r in result.named_results()]
 
 
129
 
130
  result_temp_seq = []
131
 
 
163
  token_representations = results["representations"][33]
164
  token_list = token_representations.tolist()[0][0]
165
  print(token_list)
 
 
166
 
167
+ result = st.session_state.client.query("SELECT activity, distance(representations, " + str(token_list) + ')'+ "as dist FROM default.esm_protein_indexer ORDER BY dist LIMIT 10")
168
+ result = [r for r in result.named_results()]
169
+
170
  result_temp_activity = []
171
  for i in result:
172
  # print(result_temp_seq)
 
282
 
283
 
284
  with st.spinner("Connecting DB..."):
285
+ st.session_state.meta, st.session_state.client = init_db()
286
 
287
  with st.spinner("Loading Models..."):
288
  # Initialize SAGE model