Jackmin108 commited on
Commit
aeb99cb
1 Parent(s): 8f83a35

refactor: prompts

Browse files

Signed-off-by: Meow <[email protected]>

Files changed (1) hide show
  1. modeling_lora.py +4 -2
modeling_lora.py CHANGED
@@ -165,7 +165,6 @@ class LoRAParametrization(nn.Module):
165
  ):
166
  """
167
  Registering LoRA adapters to all embedding and linear layers.
168
-
169
  Additionally, we implement a custom forward function for LoRA parametrization.
170
  This function modifies the layer's forward pass to optionally use task-specific
171
  parameters. When a `task_id` is provided, it employs a LoRA parametrization
@@ -373,7 +372,6 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
373
  ) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
374
  """
375
  Computes sentence embeddings.
376
-
377
  sentences(`str` or `List[str]`):
378
  Sentence or sentences to be encoded
379
  task_type(`str`, *optional*, defaults to `None`):
@@ -394,6 +392,10 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
394
  adapter_mask = torch.full(
395
  (num_examples,), task_id, dtype=torch.int32, device=self.device
396
  )
 
 
 
 
397
  return self.roberta.encode(
398
  sentences, *args, adapter_mask=adapter_mask, **kwargs
399
  )
 
165
  ):
166
  """
167
  Registering LoRA adapters to all embedding and linear layers.
 
168
  Additionally, we implement a custom forward function for LoRA parametrization.
169
  This function modifies the layer's forward pass to optionally use task-specific
170
  parameters. When a `task_id` is provided, it employs a LoRA parametrization
 
372
  ) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
373
  """
374
  Computes sentence embeddings.
 
375
  sentences(`str` or `List[str]`):
376
  Sentence or sentences to be encoded
377
  task_type(`str`, *optional*, defaults to `None`):
 
392
  adapter_mask = torch.full(
393
  (num_examples,), task_id, dtype=torch.int32, device=self.device
394
  )
395
+ if isinstance(sentences, str):
396
+ sentences = self._task_instructions[task_type] + sentences
397
+ else:
398
+ sentences = [self._task_instructions[task_type] + sentence for sentence in sentences]
399
  return self.roberta.encode(
400
  sentences, *args, adapter_mask=adapter_mask, **kwargs
401
  )