yinuozhang committed on
Commit
071db43
1 Parent(s): 01a1f08
Files changed (1) hide show
  1. model.py +13 -3
model.py CHANGED
@@ -14,7 +14,7 @@ import gc
14
  from torch.optim.lr_scheduler import _LRScheduler
15
  from transformers import EsmModel, PreTrainedModel
16
  from configuration import MetaLATTEConfig
17
-
18
  seed_everything(42)
19
 
20
  class GELU(nn.Module):
@@ -226,9 +226,19 @@ class MultitaskProteinModel(PreTrainedModel):
226
  config = MetaLATTEConfig.from_pretrained(pretrained_model_name_or_path)
227
 
228
  model = cls(config)
229
- state_dict = torch.load(f"{pretrained_model_name_or_path}/pytorch_model.bin", map_location=torch.device('cpu'))['state_dict']
230
- model.load_state_dict(state_dict, strict=False)
 
 
 
 
 
 
 
 
 
231
  return model
 
232
 
233
  def forward(self, input_ids, attention_mask=None):
234
  outputs = self.esm_model(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)
 
14
  from torch.optim.lr_scheduler import _LRScheduler
15
  from transformers import EsmModel, PreTrainedModel
16
  from configuration import MetaLATTEConfig
17
+ from urllib.parse import urljoin
18
  seed_everything(42)
19
 
20
  class GELU(nn.Module):
 
226
  config = MetaLATTEConfig.from_pretrained(pretrained_model_name_or_path)
227
 
228
  model = cls(config)
229
+ #state_dict = torch.load(f"{pretrained_model_name_or_path}/pytorch_model.bin", map_location=torch.device('cpu'))['state_dict']
230
+ try:
231
+ state_dict_url = urljoin(f"https://huggingface.co/{pretrained_model_name_or_path}/resolve/main/", "pytorch_model.bin")
232
+ state_dict = torch.hub.load_state_dict_from_url(
233
+ state_dict_url,
234
+ map_location=torch.device('cpu')
235
+ )['state_dict']
236
+ model.load_state_dict(state_dict, strict=False)
237
+ except Exception as e:
238
+ raise RuntimeError(f"Error loading state_dict from {pretrained_model_name_or_path}/pytorch_model.bin: {e}")
239
+
240
  return model
241
+
242
 
243
  def forward(self, input_ids, attention_mask=None):
244
  outputs = self.esm_model(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)