AmelieSchreiber committed
Commit 575416b
1 Parent(s): 60da2ec

Update README.md

Files changed (1)
  1. README.md +110 -1
README.md CHANGED
@@ -1,9 +1,118 @@
  ---
  library_name: peft
+ license: mit
+ language:
+ - en
+ tags:
+ - transformers
+ - biology
+ - esm
+ - esm2
+ - protein
+ - protein language model
  ---
+ # ESM-2 RNA Binding Site LoRA
+
+ This is a Parameter-Efficient Fine-Tuning (PEFT) Low Rank Adaptation (LoRA) of
+ the [esm2_t12_35M_UR50D](https://huggingface.co/facebook/esm2_t12_35M_UR50D) model for the (binary) token classification task of
+ predicting RNA binding sites of proteins. You can also find a version of this model
+ that was fine-tuned without LoRA [here](https://huggingface.co/AmelieSchreiber/esm2_t6_8M_UR50D_rna_binding_site_predictor).
+
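+ As a rough sketch of what this means in code, a LoRA adapter for ESM-2 token classification can be attached with `peft` along the following lines. The rank, alpha, dropout, and `target_modules` values here are illustrative assumptions, not necessarily the configuration used for this checkpoint:
+
+ ```python
+ from transformers import AutoModelForTokenClassification
+ from peft import LoraConfig, TaskType, get_peft_model
+
+ # Base ESM-2 model with a two-label token classification head
+ base_model = AutoModelForTokenClassification.from_pretrained(
+     "facebook/esm2_t12_35M_UR50D", num_labels=2
+ )
+
+ # Illustrative LoRA settings (the actual values for this checkpoint may differ)
+ lora_config = LoraConfig(
+     task_type=TaskType.TOKEN_CLS,
+     r=16,
+     lora_alpha=16,
+     lora_dropout=0.05,
+     target_modules=["query", "key", "value"],  # ESM-2 attention projections
+ )
+
+ # Wrap the base model; only the low-rank adapter weights remain trainable
+ model = get_peft_model(base_model, lora_config)
+ model.print_trainable_parameters()
+ ```
+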
  ## Training procedure

- ### Framework versions
+ This is a Low Rank Adaptation (LoRA) of `esm2_t12_35M_UR50D`,
+ trained on `166` protein sequences in the [RNA binding sites dataset](https://huggingface.co/datasets/AmelieSchreiber/data_of_protein-rna_binding_sites)
+ using an `85/15` train/test split. This model was trained with class weighting due to the imbalanced nature
+ of the RNA binding site dataset (fewer binding sites than non-binding sites). This model has slightly improved
+ precision, recall, and F1 score over [AmelieSchreiber/esm2_t12_35M_weighted_lora_rna_binding](https://huggingface.co/AmelieSchreiber/esm2_t12_35M_weighted_lora_rna_binding)
+ but may suffer from mild overfitting, as indicated by the training loss being slightly lower than the eval loss. If you are searching for
+ binding sites and aren't worried about false positives, the higher recall may make this model preferable to the other RNA binding site predictors.
+
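+ The class weighting can be implemented by overriding the loss in a `Trainer` subclass. The sketch below is one way to do it and is not necessarily the exact code used for this model; the weight values are placeholders that you would derive from the label frequencies in your own training split (for example, inverse class frequency):
+
+ ```python
+ import torch
+ from transformers import Trainer
+
+ # Placeholder weights for (non-binding, binding); compute them from your training labels
+ class_weights = torch.tensor([1.0, 5.0])
+
+ class WeightedTrainer(Trainer):
+     """Trainer that applies class-weighted cross-entropy to the token logits."""
+     def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
+         labels = inputs.pop("labels")
+         outputs = model(**inputs)
+         logits = outputs.logits
+         loss_fct = torch.nn.CrossEntropyLoss(
+             weight=class_weights.to(logits.device),
+             ignore_index=-100,  # positions masked out during tokenization
+         )
+         loss = loss_fct(logits.view(-1, logits.shape[-1]), labels.view(-1))
+         return (loss, outputs) if return_outputs else loss
+ ```
+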
+ You can train your own version
+ using [this notebook](https://huggingface.co/AmelieSchreiber/esm2_t6_8M_weighted_lora_rna_binding/blob/main/LoRA_binding_sites_no_sweeps_v2.ipynb)!
+ You just need the RNA `binding_sites.xml` file [found here](https://huggingface.co/datasets/AmelieSchreiber/data_of_protein-rna_binding_sites).
+ You may also need to run some `pip install` statements at the beginning of the script. If you are running in Colab, run:
+
+ ```python
+ !pip install transformers[torch] datasets peft -q
+ ```
+ ```python
+ !pip install accelerate -U -q
+ ```
+ Try to improve upon these metrics by adjusting the hyperparameters:
+ ```
+ {'eval_loss': 0.500779926776886,
+ 'eval_precision': 0.1708695652173913,
+ 'eval_recall': 0.8397435897435898,
+ 'eval_f1': 0.2839595375722543,
+ 'eval_auc': 0.771835775620126,
+ 'epoch': 11.0}
+ {'loss': 0.4171,
+ 'learning_rate': 0.00032491416877500004,
+ 'epoch': 11.43}
+ ```
+
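+ Most of those hyperparameters live in `TrainingArguments` (together with the LoRA rank and alpha shown earlier). The values below are illustrative assumptions rather than the settings used for this checkpoint, but they show the knobs you would turn; the logged learning rate above is in roughly this range:
+
+ ```python
+ from transformers import TrainingArguments
+
+ # Illustrative settings; adjust these (and the LoRA config) to try to beat the metrics above
+ training_args = TrainingArguments(
+     output_dir="esm2_rna_binding_lora",
+     learning_rate=3e-4,
+     per_device_train_batch_size=8,
+     per_device_eval_batch_size=8,
+     num_train_epochs=12,
+     weight_decay=0.01,
+     evaluation_strategy="epoch",
+     save_strategy="epoch",
+     load_best_model_at_end=True,
+     metric_for_best_model="eval_f1",
+ )
+ ```
+ These arguments would then be passed to the (class-weighted) `Trainer` along with the tokenized train/test splits.
+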
+ A similar model can also be trained using the GitHub repository, which includes a training script and a conda environment YAML and can be
+ [found here](https://github.com/Amelie-Schreiber/esm2_LoRA_binding_sites/tree/main). This version uses wandb sweeps for hyperparameter search.
+ However, it does not use class weighting.
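+ If you want to set up a similar sweep yourself, a minimal wandb sweep configuration could look like the sketch below. The search space is purely illustrative (the actual configuration lives in the linked repo), and `train_fn` stands in for whatever training function you define:
+
+ ```python
+ import wandb
+
+ # Illustrative search space; not the one used in the linked repository
+ sweep_config = {
+     "method": "bayes",
+     "metric": {"name": "eval/f1", "goal": "maximize"},
+     "parameters": {
+         "learning_rate": {"min": 1e-5, "max": 5e-4},
+         "lora_r": {"values": [4, 8, 16]},
+         "lora_alpha": {"values": [8, 16, 32]},
+         "num_train_epochs": {"values": [8, 12, 16]},
+     },
+ }
+
+ sweep_id = wandb.sweep(sweep_config, project="esm2-rna-binding-lora")
+ wandb.agent(sweep_id, function=train_fn, count=20)  # train_fn: your training function
+ ```
+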


+ ### Framework versions
+
  - PEFT 0.4.0
+
+ ## Using the Model
+
+ To use the model, try running the following `pip install` command:
+ ```python
+ !pip install transformers peft -q
+ ```
+ then try running:
+ ```python
+ from transformers import AutoModelForTokenClassification, AutoTokenizer
+ from peft import PeftModel
+ import torch
+
+ # Path to the saved LoRA model
+ model_path = "AmelieSchreiber/esm2_t12_35M_UR50D_RNA_LoRA_weighted"
+ # ESM2 base model
+ base_model_path = "facebook/esm2_t12_35M_UR50D"
+
+ # Load the model
+ base_model = AutoModelForTokenClassification.from_pretrained(base_model_path)
+ loaded_model = PeftModel.from_pretrained(base_model, model_path)
+
+ # Ensure the model is in evaluation mode
+ loaded_model.eval()
+
+ # Load the tokenizer
+ loaded_tokenizer = AutoTokenizer.from_pretrained(base_model_path)
+
+ # Protein sequence for inference
+ protein_sequence = "MAVPETRPNHTIYINNLNEKIKKDELKKSLHAIFSRFGQILDILVSRSLKMRGQAFVIFKEVSSATNALRSMQGFPFYDKPMRIQYAKTDSDIIAKMKGT"  # Replace with your actual sequence
+
+ # Tokenize the sequence
+ inputs = loaded_tokenizer(protein_sequence, return_tensors="pt", truncation=True, max_length=1024, padding='max_length')
+
+ # Run the model
+ with torch.no_grad():
+     logits = loaded_model(**inputs).logits
+
+ # Get predictions
+ tokens = loaded_tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])  # Convert input ids back to tokens
+ predictions = torch.argmax(logits, dim=2)
+
+ # Define labels
+ id2label = {
+     0: "No binding site",
+     1: "Binding site"
+ }
+
+ # Print the predicted labels for each token
+ for token, prediction in zip(tokens, predictions[0].numpy()):
+     if token not in ['<pad>', '<cls>', '<eos>']:
+         print((token, id2label[prediction]))
+ ```
+
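+ The `argmax` above gives hard labels. If you would rather rank candidate sites or trade precision against recall explicitly, you can work with the class-1 probabilities instead. This sketch reuses `logits` and `tokens` from the block above; the `0.5` threshold is just an example value:
+
+ ```python
+ # Per-residue probability of the "Binding site" class
+ probs = torch.softmax(logits, dim=2)[0, :, 1]
+
+ threshold = 0.5  # example value; lower it to favor recall over precision
+ binding_residues = [
+     (i, token, float(p))
+     for i, (token, p) in enumerate(zip(tokens, probs))
+     if token not in ['<pad>', '<cls>', '<eos>'] and p >= threshold
+ ]
+
+ # i is the token index (position 0 is the <cls> token), residue is the amino acid letter
+ for i, residue, p in binding_residues:
+     print(f"position {i}: {residue} (p={p:.2f})")
+ ```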