|
import streamlit as st |
|
import torch |
|
import esm |
|
import requests |
|
import matplotlib.pyplot as plt |
|
from myscaledb import Client |
|
import random |
|
from collections import Counter |
|
from tqdm import tqdm |
|
from statistics import mean |
|
import biotite.structure.io as bsio |
|
import torch |
|
import matplotlib.pyplot as plt |
|
import numpy as np |
|
import pandas as pd |
|
import seaborn as sns |
|
from stmol import * |
|
import py3Dmol |
|
|
|
|
|
import scipy |
|
from sklearn.model_selection import GridSearchCV, train_test_split |
|
from sklearn.decomposition import PCA |
|
from sklearn.neighbors import KNeighborsClassifier, KNeighborsRegressor |
|
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor |
|
from sklearn.linear_model import LogisticRegression, SGDRegressor |
|
from sklearn.pipeline import Pipeline |
|
|
|
from streamlit.components.v1 import html |
|
|
|
|
|
@st.experimental_singleton(show_spinner=False)
def init_esm():
    """Load the pretrained ESM MSA Transformer in inference mode.

    The result is cached with Streamlit's singleton cache (the same
    decorator ``init_db`` uses), so the weights are loaded once per server
    process instead of on every call.

    Returns:
        msa_transformer: The ``esm_msa1b_t12_100M_UR50S`` model after
            ``eval()`` has been called on it.
        msa_transformer_alphabet: The alphabet object returned with the model.
    """
    msa_transformer, msa_transformer_alphabet = esm.pretrained.esm_msa1b_t12_100M_UR50S()
    return msa_transformer.eval(), msa_transformer_alphabet
|
|
|
@st.experimental_singleton(show_spinner=False)
def init_db():
    """Initialize the database connection.

    Connection settings are read from ``st.secrets`` (``DB_URL``, ``USER``
    and ``PASSWD``).

    Returns:
        meta_field: An empty dict for the caller to fill in.
        client: Database connection object.

    Raises:
        ConnectionError: If the freshly created client reports that the
            server is not alive.
    """
    client = Client(
        url=st.secrets["DB_URL"],
        user=st.secrets["USER"],
        password=st.secrets["PASSWD"],
    )

    # An explicit check instead of ``assert`` so it still runs under ``-O``.
    if not client.is_alive():
        raise ConnectionError("Database connection is not alive")

    meta_field = {}
    # Return the connected instance, not the ``Client`` class.
    return meta_field, client
|
|
|
|
|
def perdict_contact_visualization(seq, model, batch_converter):
    """Plot the contact map the model predicts for a single sequence.

    Args:
        seq: Protein sequence string.
        model: Callable model; it is invoked as
            ``model(tokens, repr_layers=[12], return_contacts=True)`` and
            its result must provide a ``"contacts"`` entry.
        batch_converter: Callable turning ``[(label, sequence)]`` into a
            ``(labels, strings, tokens)`` triple.

    Returns:
        The matplotlib figure that shows the contact map.
    """
    data = [("protein1", seq)]
    _labels, _strs, batch_tokens = batch_converter(data)

    # Inference only, so gradient tracking is switched off.
    with torch.no_grad():
        results = model(batch_tokens, repr_layers=[12], return_contacts=True)

    # Exactly one sequence was submitted; its map is the first entry.
    contacts = results["contacts"][0]

    fig, ax = plt.subplots()
    # Show only the leading len(seq) x len(seq) part of the map.
    ax.matshow(contacts[: len(seq), : len(seq)])
    return fig
|
|
|
|
|
def visualize_3D_Coordinates(coords):
    """Draw a chain of 3D points as a colour-graded line plot.

    Each pair of consecutive points is joined by one segment with circle
    markers; the segment colour is taken from the viridis colour map at
    ``index / len(coords)``.

    Args:
        coords: Sized sequence of points; items 0, 1 and 2 of every point
            are read as x, y and z.

    Returns:
        The matplotlib figure containing the 3D axes.
    """
    xs = [point[0] for point in coords]
    ys = [point[1] for point in coords]
    zs = [point[2] for point in coords]

    fig = plt.figure(figsize=(10, 10))
    ax = fig.add_subplot(111, projection='3d')
    ax.set_title('3D coordinates of $C_{b}$ backbone structure')

    total = len(coords)
    for first in range(total - 1):
        last = first + 2
        ax.plot(
            xs[first:last],
            ys[first:last],
            zs[first:last],
            color=plt.cm.viridis(first / total),
            marker='o',
        )
    return fig
|
|
|
def render_mol(pdb):
    """Show a PDB structure as a spinning cartoon in the Streamlit page.

    Args:
        pdb: Structure text, added to the viewer with the ``'pdb'`` format.
    """
    cartoon_style = {'cartoon': {'color': 'spectrum'}}

    view = py3Dmol.view()
    view.addModel(pdb, 'pdb')
    view.setStyle(cartoon_style)
    view.setBackgroundColor('white')
    # Fit the model into view first, then apply the extra zoom and spin.
    view.zoomTo()
    view.zoom(2, 800)
    view.spin(True)
    showmol(view, height=500, width=800)
|
|
|
|
|
|
|
def esm_search(model, sequnce, batch_converter,top_k=5):
    """Find stored protein sequences whose embedding is close to the query.

    The query sequence is embedded with ``model`` and the embedding is sent
    to the ``default.esm_protein_indexer_768`` table through the database's
    ``distance`` function.

    Args:
        model: Callable model; invoked as
            ``model(tokens, repr_layers=[12], return_contacts=True)``; its
            result must provide ``["representations"][12]``.
        sequnce: Query protein sequence string.
        batch_converter: Callable turning ``[(label, sequence)]`` into a
            ``(labels, strings, tokens)`` triple.
        top_k: Maximum number of distinct sequences to return.

    Returns:
        List of distinct sequences, in the order the database returned them,
        truncated to ``top_k`` entries.
    """
    data = [("protein1", sequnce)]
    _labels, _strs, batch_tokens = batch_converter(data)

    # Inference only, so gradient tracking is switched off.
    with torch.no_grad():
        results = model(batch_tokens, repr_layers=[12], return_contacts=True)
    token_representations = results["representations"][12]

    # The entry at index 0 of the three leading axes is the query vector.
    query_vector = token_representations.tolist()[0][0][0]

    client = Client(
        url=st.secrets["DB_URL"],
        user=st.secrets["USER"],
        password=st.secrets["PASSWD"],
    )

    # The vector comes from the model output, not from user-typed text.
    rows = client.fetch(
        "SELECT seq, distance('topK=500')(representations, "
        + str(query_vector) + ')'
        + "as dist FROM default.esm_protein_indexer_768"
    )

    # dict.fromkeys removes duplicates but keeps the database's row order
    # (a set would shuffle it). NOTE(review): rows are presumably returned
    # nearest-first - confirm against the database documentation.
    unique_seqs = list(dict.fromkeys(row['seq'] for row in rows))
    return unique_seqs[:top_k]
|
|
|
def show_protein_structure(sequence):
    """Fold a sequence with the ESM Atlas API and render the result.

    The sequence is posted to the ``foldSequence`` endpoint and the returned
    PDB text is passed to ``render_mol``. Nothing is written to disk, so
    concurrent sessions cannot overwrite each other's prediction.

    Args:
        sequence: Protein sequence string sent as the request body.

    Raises:
        requests.HTTPError: If the API answers with an error status.
    """
    headers = {
        'Content-Type': 'application/x-www-form-urlencoded',
    }
    response = requests.post(
        'https://api.esmatlas.com/foldSequence/v1/pdb/',
        headers=headers,
        data=sequence,
    )
    # Fail loudly instead of trying to render an error page as a structure.
    response.raise_for_status()

    pdb_string = response.content.decode('utf-8')
    render_mol(pdb_string)
|
|
|
@st.experimental_singleton(show_spinner=False)
def _load_esm2():
    """Load the ESM-2 650M model once and reuse it for later calls."""
    model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
    model.eval()
    return model, alphabet


def KNN_search(sequence):
    """Estimate an activity value from the nearest stored proteins.

    The sequence is embedded with ESM-2 (layer 33), the embedding is sent to
    the ``default.esm_protein_indexer`` table with ``topK=10``, and the
    ``activity`` values of the returned rows are averaged.

    Args:
        sequence: Protein sequence string.

    Returns:
        Arithmetic mean of the ``activity`` values of the returned rows.

    Raises:
        ValueError: If the database returns no rows.
    """
    model, alphabet = _load_esm2()
    batch_converter = alphabet.get_batch_converter()

    data = [("protein1", sequence)]
    _labels, _strs, batch_tokens = batch_converter(data)

    # Only the representations are used, so contact prediction is not
    # requested; gradients are not needed for inference.
    with torch.no_grad():
        results = model(batch_tokens, repr_layers=[33])
    token_representations = results["representations"][33]

    # The entry at index 0 of the two leading axes is the query vector.
    query_vector = token_representations.tolist()[0][0]

    client = Client(
        url=st.secrets["DB_URL"],
        user=st.secrets["USER"],
        password=st.secrets["PASSWD"],
    )

    # The vector comes from the model output, not from user-typed text.
    rows = client.fetch(
        "SELECT activity, distance('topK=10')(representations, "
        + str(query_vector) + ')'
        + "as dist FROM default.esm_protein_indexer"
    )

    activities = [row['activity'] for row in rows]
    if not activities:
        raise ValueError("The database returned no neighbours to average")
    return sum(activities) / len(activities)
|
|
|
|
|
|
|
def train_test_split_PCA(
    dataset,
    fasta_path='/root/xuying_experiments/esm-main/P62593.fasta',
    emb_path='/root/xuying_experiments/esm-main/P62593_reprs',
    layer=34,
    train_size=0.8,
    random_state=42,
):
    """Load labelled embeddings and split them into train and test sets.

    For every FASTA record the label is the text after the last ``|`` of the
    header, parsed as a float, and the feature vector is read from
    ``<emb_path>/<header>.pt`` under ``['mean_representations'][layer]``.

    Args:
        dataset: Not used by this function; kept for existing callers.
        fasta_path: FASTA file whose headers carry the labels.
        emb_path: Directory holding one ``.pt`` file per FASTA header.
        layer: Key selected inside ``mean_representations``.
        train_size: Fraction of samples placed in the training split.
        random_state: Seed passed to ``train_test_split``.

    Returns:
        ``(Xs_train, Xs_test, ys_train, ys_test)``.
    """
    ys = []
    Xs = []
    for header, _seq in esm.data.read_fasta(fasta_path):
        ys.append(float(header.split('|')[-1]))
        # NOTE: torch.load unpickles the file - only use trusted embeddings.
        embs = torch.load(f'{emb_path}/{header}.pt')
        Xs.append(embs['mean_representations'][layer])

    Xs = torch.stack(Xs, dim=0).numpy()
    Xs_train, Xs_test, ys_train, ys_test = train_test_split(
        Xs, ys, train_size=train_size, random_state=random_state
    )
    return Xs_train, Xs_test, ys_train, ys_test
|
|
|
def PCA_visual(Xs_train, ys_train=None):
    """Scatter-plot the first two principal components of the embeddings.

    Args:
        Xs_train: Feature matrix handed to ``PCA.fit_transform``.
        ys_train: Optional values, one per row of ``Xs_train``, used to
            colour the points. When omitted the points are drawn in the
            default colour and no colour bar is added.

    Returns:
        The matplotlib figure.
    """
    num_pca_components = 60
    pca = PCA(num_pca_components)
    Xs_train_pca = pca.fit_transform(Xs_train)

    fig_dims = (4, 4)
    fig, ax = plt.subplots(figsize=fig_dims)
    ax.set_title('Visualize Embeddings')
    # The colour values are an explicit argument; the function no longer
    # depends on a ``ys_train`` name existing outside of it.
    sc = ax.scatter(Xs_train_pca[:, 0], Xs_train_pca[:, 1], c=ys_train, marker='.')
    ax.set_xlabel('PCA first principal component')
    ax.set_ylabel('PCA second principal component')
    if ys_train is not None:
        plt.colorbar(sc, label='Variant Effect')

    return fig
|
|
|
def KNN_trainings(Xs_train, Xs_test, ys_train, ys_test):
    """Grid-search a PCA + k-nearest-neighbours regression pipeline.

    The pipeline is tuned with ``GridSearchCV`` (scoring ``'r2'``) on the
    training data; the test split is accepted but not used here.

    Args:
        Xs_train: Training features.
        Xs_test: Test features (unused).
        ys_train: Training targets.
        ys_test: Test targets (unused).

    Returns:
        DataFrame with the five best-ranked parameter combinations and the
        columns ``param_model``, ``params``, ``param_model__algorithm``,
        ``mean_test_score`` and ``rank_test_score``.
    """
    n_components = 60

    search_space = [
        {
            'model': [KNeighborsRegressor()],
            'model__n_neighbors': [5, 10],
            'model__weights': ['uniform', 'distance'],
            'model__algorithm': ['ball_tree', 'kd_tree', 'brute'],
            'model__leaf_size': [15, 30],
            'model__p': [1, 2],
        }
    ]

    pipeline = Pipeline(
        steps=(
            ('pca', PCA(n_components)),
            ('model', KNeighborsRegressor()),
        )
    )

    # Only one estimator family is searched, so no loop is needed.
    print(KNeighborsRegressor)
    search = GridSearchCV(
        estimator=pipeline,
        param_grid=search_space,
        scoring='r2',
        verbose=1,
        n_jobs=-1,
    )
    search.fit(Xs_train, ys_train)

    cv_table = pd.DataFrame.from_dict(search.cv_results_)
    best_five = pd.DataFrame(cv_table.sort_values('rank_test_score')[:5])

    report_columns = [
        'param_model',
        'params',
        'param_model__algorithm',
        'mean_test_score',
        'rank_test_score',
    ]
    return best_five[report_columns]
|
|
|
|
|
# Stylesheet link for the Roboto web font, injected as raw HTML.
_FONT_LINK_HTML = """
<link
rel="stylesheet"
href="https://fonts.googleapis.com/css?family=Roboto:300,400,500,700&display=swap"
/>
"""
st.markdown(_FONT_LINK_HTML, unsafe_allow_html=True)

# Introductory text shown in the info box at the top of the page.
_INTRO_MESSAGE = """
Evolutionary-scale prediction of atomic level protein structure

ESM is a high-capacity Transformer trained with protein sequences \
as input. After training, the secondary and tertiary structure, \
function, homology and other information of the protein are in the feature representation output by the model.\
Check out https://esmatlas.com/ for more information.

We have 120k proteins features stored in our database.

The app uses MyScale to store and query protein sequence
using vector search.
"""

messages = [_INTRO_MESSAGE]
|
@st.experimental_singleton(show_spinner=False)
def init_random_query():
    """Create a random query vector together with a copy of it.

    Returns:
        A pair of equal lists of random floats; the second one is a separate
        copy of the first.
    """
    # NOTE(review): DIMS is not defined in the visible part of this file -
    # confirm where it is supposed to come from.
    query = np.random.rand(DIMS).tolist()
    duplicate = list(query)
    return query, duplicate
|
|
|
|
|
# ---------------------------------------------------------------------------
# Page body: connect to the database, load the model once per session, then
# render the tool picked in the "Application options" select box.
# ---------------------------------------------------------------------------

# Sequences filled in by the example buttons.
CAS9_EXAMPLE = 'GSGHMDKKYSIGLAIGTNSVGWAVITDEYKVPSKKFKVLGNTDRHSIKKNLIGALLFDSGETAEATRLKRTARRRYTRRKNRILYLQEIFSNEMAKV'
PETASE_EXAMPLE = 'MGSSHHHHHHSSGLVPRGSHMRGPNPTAASLEASAGPFTVRSFTVSRPSGYGAGTVYYPTNAGGTVGAIAIVPGYTARQSSIKWWGPRLASHGFVVITIDTNSTLDQPSSRSSQQMAALRQVASLNGTSSSPIYGKVDTARMGVMGWSMGGGGSLISAANNPSLKAAAPQAPWDSSTNFSSVTVPTLIFACENDSIAPVNSSALPIYDSMSRNAKQFLEINGGSHSCANSGNSNQALIGKKGVAWMKRFMDNDTRYSTFACENPNSTRVSDFRTANCSLEDPAANKARKEAELAAATAEQ'

# Text shown in the "See explanation" expanders.
CONTACT_EXPLANATION_HTML = """<span style="word-wrap:break-word;">Contact prediction is based on a logistic regression over the model's attention maps. This methodology is based on ICLR 2021 paper, Transformer protein language models are unsupervised structure learners. (Rao et al. 2020)The MSA Transformer (ESM-MSA-1) takes a multiple sequence alignment (MSA) as input, and uses the tied row self-attention maps in the same way.</span>
"""
PDB_EXPLANATION = """
A PDB ID is a unique 4-character code for each entry in the Protein Data Bank. The first character must be a number between 1 and 9, and the remaining three characters can be letters or numbers.
see https://www.rcsb.org/ for more information.
"""

# Number of search hits shown as tabs.
MAX_RESULT_TABS = 5


def _sequence_input(show_hint=True):
    """Render the sequence box and example buttons; return the sequence.

    Args:
        show_hint: Whether to print the 'Try an example:' line above the
            example buttons.

    Returns:
        The typed sequence, or an example sequence when its button was
        pressed on this run.
    """
    sequence = st.text_input('protein sequence', '')
    if show_hint:
        st.write('Try an example:')
    if st.button('Cas9 Enzyme'):
        sequence = CAS9_EXAMPLE
    elif st.button('PETase'):
        sequence = PETASE_EXAMPLE
    return sequence


with st.spinner("Connecting DB..."):
    st.session_state.meta, client = init_db()

with st.spinner("Loading Models..."):
    if 'xq' not in st.session_state:
        model, alphabet = init_esm()
        # Store the model and batch converter immediately, so every rerun
        # finds them in session state and takes the same UI path.
        st.session_state['xq'] = model
        st.session_state['batch'] = alphabet.get_batch_converter()
        st.session_state.query_num = 0

if st.session_state.query_num < len(messages):
    msg = messages[0]
else:
    msg = messages[-1]

with st.container():
    st.title("Evolutionary Scale Modeling")
    start = [st.empty() for _ in range(9)]
    start[0].info(msg)
    option = st.selectbox(
        'Application options',
        ('self-contact prediction', 'search the database', 'activity prediction', 'PDB viewer'),
    )

    st.session_state.db_name_ref = 'default.esm_protein'

    if option == 'self-contact prediction':
        sequence = _sequence_input(show_hint=False)

        if sequence:
            st.write('you have entered: ', sequence)
            start[2] = st.pyplot(
                perdict_contact_visualization(
                    sequence, st.session_state['xq'], st.session_state['batch']
                )
            )
            expander = st.expander("See explanation")
            expander.markdown(CONTACT_EXPLANATION_HTML, unsafe_allow_html=True)

    elif option == 'search the database':
        sequence = _sequence_input()

        if sequence:
            st.write('you have entered: ', sequence)

            # Model and converter come from session state; the local names
            # from the loading step do not exist on reruns.
            result_temp_seq = esm_search(
                st.session_state['xq'], sequence, st.session_state['batch'], top_k=10
            )
            st.text('search result (top 5): ')

            # Build only as many tabs as there are hits, so a short result
            # list cannot cause an IndexError.
            top_hits = result_temp_seq[:MAX_RESULT_TABS]
            if top_hits:
                tabs = st.tabs([str(rank) for rank in range(1, len(top_hits) + 1)])
                for tab, hit in zip(tabs, top_hits):
                    with tab:
                        st.write(hit)
                        show_protein_structure(hit)
            else:
                st.write('no similar proteins found')

    elif option == 'activity prediction':
        st.markdown('we predict the biological activity of mutations of a protein, using fixed embeddings from ESM.')

        sequence = _sequence_input()
        if sequence:
            st.write('you have entered: ', sequence)
            res_knn = KNN_search(sequence)
            st.subheader('KNN predictor result')
            start[2] = st.markdown("Activity prediction: " + str(res_knn))

    elif option == 'PDB viewer':
        id_PDB = st.text_input('enter PDB ID', '')
        residues_marker = st.text_input('residues class', '')
        st.write('Try an example:')
        if st.button('PDB ID: 1A2C / residues class: ALA'):
            id_PDB = '1A2C'
            residues_marker = 'ALA'

        st.subheader('PDB viewer')
        if residues_marker:
            start[7] = showmol(
                render_pdb_resn(viewer=render_pdb(id=id_PDB), resn_lst=[residues_marker])
            )
        else:
            start[7] = showmol(render_pdb(id=id_PDB))

        expander = st.expander("See explanation")
        expander.markdown(PDB_EXPLANATION)