jupyterjazz committed on
Commit 5418705
1 Parent(s): a2b7c86

refactor: disable lora by default


Signed-off-by: jupyterjazz <[email protected]>

Files changed (2)
  1. configuration_xlm_roberta.py +2 -0
  2. modeling_lora.py +3 -3
configuration_xlm_roberta.py CHANGED
@@ -26,6 +26,7 @@ class XLMRobertaFlashConfig(PretrainedConfig):
         lora_rank=4,
         lora_dropout_p=0.0,
         lora_alpha=1,
+        lora_main_params_trainable=False,
         load_trained_adapters=False,
         use_flash_attn=True,
         torch_dtype=None,
@@ -55,6 +56,7 @@ class XLMRobertaFlashConfig(PretrainedConfig):
         self.lora_rank = lora_rank
         self.lora_dropout_p = lora_dropout_p
         self.lora_alpha = lora_alpha
+        self.lora_main_params_trainable = lora_main_params_trainable
         self.use_flash_attn = use_flash_attn
         self.emb_pooler = emb_pooler
         if torch_dtype and hasattr(torch, torch_dtype) and type(getattr(torch, torch_dtype)) is torch.dtype:
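
For context, a minimal usage sketch of the new flag (not part of this commit; the import path is assumed from the file name configuration_xlm_roberta.py). The default of False keeps the main, non-LoRA weights frozen unless the caller opts in:

from configuration_xlm_roberta import XLMRobertaFlashConfig

# The new flag now defaults to False:
config = XLMRobertaFlashConfig()
print(config.lora_main_params_trainable)  # False: main (non-LoRA) weights stay frozen

# Opting in at construction time keeps the base weights trainable alongside the adapters:
config = XLMRobertaFlashConfig(lora_main_params_trainable=True)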
modeling_lora.py CHANGED
@@ -226,10 +226,10 @@ class XLMRobertaLoRA(XLMRobertaModel):
             dropout_p=self._dropout_p,
             alpha=self._alpha,
         )
-        self.main_params_trainable = False
+        self.main_params_trainable = config.lora_main_params_trainable
         self._task_idx = None
-        # By default, we select the first LoRA
-        self.current_task = 0
+        # By default, disable LoRA until it's specified which adapter/task to use
+        self.current_task = None
 
 
     @property
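
A hedged sketch of the behavioral change (not part of the commit): a freshly constructed XLMRobertaLoRA now starts with no adapter selected, whereas it previously activated the first LoRA implicitly. Assigning to current_task is an assumption based on the old self.current_task = 0 line and the @property shown in the diff:

from modeling_lora import XLMRobertaLoRA

model = XLMRobertaLoRA(config)

# After this commit, LoRA is disabled until an adapter/task is chosen explicitly.
assert model.current_task is None
assert model.main_params_trainable == config.lora_main_params_trainable  # False by default

# Assumed setter usage: select the first LoRA adapter, as the old default did implicitly.
model.current_task = 0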