jupyterjazz committed
Commit 4b000ec · Parent: f9b3adb

Refactor LoRA (#8)


- refactor: lora (e0ea168f256f54e7728033ff8afbf9bb71c617cc)
- refactor: remove pooling layer stuff (c6a5a4d6aa2d39e1b0691e4f48869aa8d7e34b09)
- refactor: restructure the class (a2b7c8644033cc4318c6caa6730836023776faa9)
- refactor: disable lora by default (5418705c2b50051908c38d6df1055df4f21274a2)
- refactor: set task in lora class rather than xlm roberta (851aaca7b1e7f9dbffaacd3b070231e0f94401cb)
- refactor: stuff (370394630d973381185d21d153a5e46d3b9fc6da)

configuration_xlm_roberta.py CHANGED
@@ -22,7 +22,11 @@ class XLMRobertaFlashConfig(PretrainedConfig):
         position_embedding_type="absolute",
         use_cache=True,
         classifier_dropout=None,
-        num_loras=1,
+        lora_adaptations=None,
+        lora_rank=4,
+        lora_dropout_p=0.0,
+        lora_alpha=1,
+        lora_main_params_trainable=False,
         load_trained_adapters=False,
         use_flash_attn=True,
         torch_dtype=None,
@@ -47,8 +51,12 @@ class XLMRobertaFlashConfig(PretrainedConfig):
         self.position_embedding_type = position_embedding_type
         self.use_cache = use_cache
         self.classifier_dropout = classifier_dropout
-        self.num_loras = num_loras
         self.load_trained_adapters = load_trained_adapters
+        self.lora_adaptations = lora_adaptations
+        self.lora_rank = lora_rank
+        self.lora_dropout_p = lora_dropout_p
+        self.lora_alpha = lora_alpha
+        self.lora_main_params_trainable = lora_main_params_trainable
         self.use_flash_attn = use_flash_attn
         self.emb_pooler = emb_pooler
         if torch_dtype and hasattr(torch, torch_dtype) and type(getattr(torch, torch_dtype)) is torch.dtype:
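The single `num_loras` count is replaced by named, per-adapter settings. A minimal sketch of how the new fields might be populated follows; the adaptation names, the chosen values, and the local import path are illustrative assumptions, not part of this commit.

from configuration_xlm_roberta import XLMRobertaFlashConfig  # assumes the repo files are importable locally

# Hypothetical adapter names -- the refactor only requires a non-empty list.
config = XLMRobertaFlashConfig(
    lora_adaptations=["retrieval", "classification"],  # one LoRA adapter per task name
    lora_rank=4,                        # rank of the low-rank A/B factors
    lora_dropout_p=0.0,                 # dropout applied on the LoRA path
    lora_alpha=1,                       # scaling factor; scaling = alpha / rank
    lora_main_params_trainable=False,   # keep the base model weights frozen
)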
modeling_lora.py CHANGED
@@ -1,22 +1,27 @@
 import math
 import os
+import warnings
 from functools import partial
-from typing import Iterator, Optional, Tuple, Union
+from typing import Iterator, List, Optional, Tuple, Union
 
+import numpy as np
 import torch
 import torch.nn.utils.parametrize as parametrize
 from torch import nn
 from torch.nn import Parameter
 from transformers import PretrainedConfig
 
-from .modeling_xlm_roberta import XLMRobertaModel, XLMRobertaPreTrainedModel, XLMRobertaFlashConfig
+from .modeling_xlm_roberta import XLMRobertaFlashConfig, XLMRobertaModel
+
+
+LORA_NO_UPDATE = '__lora_no_update__'
 
 
 def initialized_weights(
-    shape: Tuple[int], num_adaptions: int, init: str = "kaiming"
+    shape: Tuple[int], num_adaptations: int, init: str = "kaiming"
 ) -> torch.Tensor:
     weight_data = []
-    for _ in range(num_adaptions):
+    for _ in range(num_adaptations):
         new_adaption = torch.zeros(shape)
         if init == "kaiming":
             nn.init.kaiming_uniform_(new_adaption, a=math.sqrt(5))
@@ -45,15 +50,16 @@ class LoRAParametrization(nn.Module):
     WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
     SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
     """
+
     def __init__(
         self,
         fan_in: int,
         fan_out: int,
         layer_type: str = "linear",
-        num_adaptions: int = 1,
+        num_adaptations: int = 1,
         rank: int = 4,
-        lora_dropout_p: float = 0.0,
-        lora_alpha: float = 1,
+        dropout_p: float = 0.0,
+        alpha: float = 1,
     ):
         super().__init__()
         # if weight is stored as (fan_out, fan_in), the memory layout of A & B follows (W + BA)x
@@ -63,25 +69,23 @@ class LoRAParametrization(nn.Module):
 
         if layer_type == "linear":
             self.lora_A = nn.Parameter(
-                initialized_weights((rank, fan_in), num_adaptions, init="kaiming")
+                initialized_weights((rank, fan_in), num_adaptations, init="kaiming")
             )
-            self.lora_B = nn.Parameter(torch.zeros((num_adaptions, fan_out, rank)))
+            self.lora_B = nn.Parameter(torch.zeros((num_adaptations, fan_out, rank)))
         elif layer_type == "embedding":
-            self.lora_A = nn.Parameter(torch.zeros((num_adaptions, fan_in, rank)))
+            self.lora_A = nn.Parameter(torch.zeros((num_adaptations, fan_in, rank)))
             self.lora_B = nn.Parameter(
                 initialized_weights(
-                    (rank, fan_out), num_adaptions=num_adaptions, init="normal"
+                    (rank, fan_out), num_adaptations=num_adaptations, init="normal"
                 )
             )
         else:
             raise NotImplementedError
 
-        self.lora_alpha, self.rank = lora_alpha, rank
-        self.scaling = lora_alpha / rank
-        self.lora_dropout = (
-            nn.Dropout(p=lora_dropout_p) if lora_dropout_p > 0 else lambda x: x
-        )
-        self.dropout_fn = self._dropout if lora_dropout_p > 0 else lambda x: x
+        self.lora_alpha, self.rank = alpha, rank
+        self.scaling = alpha / rank
+        self.lora_dropout = nn.Dropout(p=dropout_p) if dropout_p > 0 else lambda x: x
+        self.dropout_fn = self._dropout if dropout_p > 0 else lambda x: x
         self.register_buffer(
             "lora_dropout_mask",
             torch.ones(self.swap((1, fan_in)), dtype=self.lora_A.dtype),
@@ -128,42 +132,52 @@ class LoRAParametrization(nn.Module):
     def from_linear(
         cls,
         layer: nn.Module,
-        num_adaptions: int = 1,
-        rank: int = 4,
-        lora_dropout_p: float = 0.0,
-        lora_alpha: int = 1,
+        num_adaptations: int,
+        rank: int,
+        dropout_p: float,
+        alpha: float,
     ):
         assert isinstance(layer, nn.Linear)
         fan_out, fan_in = layer.weight.shape
         return cls(
             fan_in,
             fan_out,
-            num_adaptions=num_adaptions,
+            num_adaptations=num_adaptations,
             layer_type="linear",
             rank=rank,
-            lora_dropout_p=lora_dropout_p,
-            lora_alpha=lora_alpha,
+            dropout_p=dropout_p,
+            alpha=alpha,
         )
 
     @classmethod
     def from_embedding(
-        cls, layer, num_adaptions=1, rank=4, lora_dropout_p=0.0, lora_alpha=1
+        cls,
+        layer: nn.Module,
+        num_adaptations: int,
+        rank: int,
+        dropout_p: float,
+        alpha: float,
     ):
         assert isinstance(layer, nn.Embedding)
         fan_in, fan_out = layer.weight.shape
         return cls(
             fan_in,
             fan_out,
-            num_adaptions=num_adaptions,
+            num_adaptations=num_adaptations,
             layer_type="embedding",
             rank=rank,
-            lora_dropout_p=lora_dropout_p,
-            lora_alpha=lora_alpha,
+            dropout_p=dropout_p,
+            alpha=alpha,
         )
 
     @classmethod
     def add_to_layer(
-        cls, layer, num_adaptions=1, rank=4, lora_dropout_p=0.0, lora_alpha=1
+        cls,
+        layer: nn.Module,
+        num_adaptations: int,
+        rank: int,
+        dropout_p: float,
+        alpha: float,
     ):
         if isinstance(layer, nn.Linear):
             parametrize.register_parametrization(
@@ -171,10 +185,10 @@ class LoRAParametrization(nn.Module):
                 "weight",
                 cls.from_linear(
                     layer,
-                    num_adaptions=num_adaptions,
+                    num_adaptations=num_adaptations,
                     rank=rank,
-                    lora_dropout_p=lora_dropout_p,
-                    lora_alpha=lora_alpha,
+                    dropout_p=dropout_p,
+                    alpha=alpha,
                 ),
             )
         elif isinstance(layer, nn.Embedding):
@@ -183,10 +197,10 @@ class LoRAParametrization(nn.Module):
                 "weight",
                 cls.from_embedding(
                     layer,
-                    num_adaptions=num_adaptions,
+                    num_adaptations=num_adaptations,
                     rank=rank,
-                    lora_dropout_p=lora_dropout_p,
-                    lora_alpha=lora_alpha,
+                    dropout_p=dropout_p,
+                    alpha=alpha,
                 ),
             )
 
@@ -195,30 +209,39 @@ class LoRAParametrization(nn.Module):
         if isinstance(layer, LoRAParametrization):
            layer.current_task = task_idx
 
-    @staticmethod
-    def merge_lora_into_layer(layer: nn.Module):
-        if hasattr(layer, "parametrizations"):
-            for attr_name in layer.parametrizations.keys():
-                parametrize.remove_parametrizations(layer, attr_name, leave_parametrized=True)
-
 
-class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
-    def __init__(self, config: XLMRobertaFlashConfig, roberta: Optional[XLMRobertaModel] = None, add_pooling_layer=True):
+class XLMRobertaLoRA(XLMRobertaModel):
+    def __init__(
+        self,
+        config: XLMRobertaFlashConfig,
+    ):
         super().__init__(config)
 
-        if roberta is None:
-            self.roberta = XLMRobertaModel(config, add_pooling_layer=add_pooling_layer)
-        else:
-            self.roberta = roberta
-
-        self._is_merged = False
-        self._num_adaptions = config.num_loras
-        self._register_lora(self._num_adaptions)
-
-        self.main_params_trainable = False
+        self._lora_adaptations = config.lora_adaptations
+        if (
+            not isinstance(self._lora_adaptations, list)
+            or len(self._lora_adaptations) < 1
+        ):
+            raise ValueError(
+                f'`lora_adaptations` must be a list and contain at least one element'
+            )
+        self._adaptation_map = {
+            name: idx for idx, name in enumerate(self._lora_adaptations)
+        }
+        self._rank = config.lora_rank
+        self._dropout_p = config.lora_dropout_p
+        self._alpha = config.lora_alpha
+
+        self._register_lora(
+            num_adaptations=self._num_adaptations,
+            rank=self._rank,
+            dropout_p=self._dropout_p,
+            alpha=self._alpha,
+        )
+        self.main_params_trainable = config.lora_main_params_trainable
         self._task_idx = None
-        # By default, we select the first LoRA
-        self.current_task = 0
+        # By default, disable LoRA until it's specified which adapter/task to use
+        self.current_task = None
 
     @property
     def main_params_trainable(self):
@@ -237,13 +260,6 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
             if "lora" not in name:
                 param.requires_grad_(val)
 
-    def merge_lora(self):
-        """Merges currently selected LoRA into main weights."""
-        if self._is_merged:
-            raise Exception('LoRA has already been merged, cannot merge again')
-        self._is_merged = True
-        self.apply(LoRAParametrization.merge_lora_into_layer)
-
     @classmethod
     def from_pretrained(
         cls,
@@ -259,46 +275,52 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
         use_safetensors: bool = None,
         **kwargs,
     ):
-        config = XLMRobertaFlashConfig.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
+        config = XLMRobertaFlashConfig.from_pretrained(
+            pretrained_model_name_or_path, *model_args, **kwargs
+        )
+
         if config.load_trained_adapters:
             return super().from_pretrained(
-                pretrained_model_name_or_path,
-                *model_args,
-                **kwargs
+                pretrained_model_name_or_path, *model_args, **kwargs
             )
         else:
-            roberta = XLMRobertaModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
-            return cls(config, roberta=roberta)
+            dtype = config.torch_dtype if config.torch_dtype else torch.bfloat16
+            torch.set_default_dtype(dtype)
+            return cls(config)
 
-    def _register_lora(self, num_adaptions=1, rank=4, lora_dropout_p=0.0, lora_alpha=1):
+    def _register_lora(self, num_adaptations, rank, dropout_p, alpha):
         self.apply(
             partial(
                 LoRAParametrization.add_to_layer,
-                num_adaptions=num_adaptions,
+                num_adaptations=num_adaptations,
                 rank=rank,
-                lora_dropout_p=lora_dropout_p,
-                lora_alpha=lora_alpha,
+                dropout_p=dropout_p,
+                alpha=alpha,
             )
         )
 
     @property
     def current_task(self):
-        """ Which LoRA is currently selected
+        """Which LoRA is currently selected
         :return: Integer or None (when LoRA is disabled)
         """
         return self._task_idx
 
     @current_task.setter
-    def current_task(self, task_idx: Union[None, int]):
+    def current_task(self, task_name: Union[None, str]):
         """Set the LoRA that is to be used.
         The LoRA is specified by `task_idx`, which may be an integer >= 0,
         indexing the available LoRAs. If it is None, no LoRA is used.
-        :param task_idx: Which LoRA to use
+        :param task_name: Which LoRA to use
         :return:
         """
-        if self._is_merged:
-            raise Exception('LoRA has been merged, cannot select new task')
-        assert task_idx is None or 0 <= task_idx < self._num_adaptions
+        if task_name and task_name not in self._lora_adaptations:
+            raise ValueError(
+                f"Unsupported task '{task_name}'. "
+                f"Supported tasks are: {', '.join(self.config.lora_adaptations)}."
+                f"Alternatively, set `task` to `None` if you want to disable LoRA."
+            )
+        task_idx = self._adaptation_map[task_name] if task_name else None
         if self._task_idx != task_idx:
             # In this case, we need to update the LoRAs everywhere
             self._task_idx = task_idx
@@ -306,10 +328,10 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
                 partial(LoRAParametrization.select_task_for_layer, task_idx=task_idx)
             )
 
-    def forward(self, *args, current_task: Union[None, int] = -1, **kwargs):
-        if current_task is None or current_task >= 0:
-            self.current_task = current_task
-        return self.roberta(*args, **kwargs)
+    def forward(self, *args, task: Union[str, None] = LORA_NO_UPDATE, **kwargs):
+        if task != LORA_NO_UPDATE:
+            self.current_task = task
+        return super().forward(*args, **kwargs)
 
     def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
         for _, param in self.named_parameters(recurse=recurse):
@@ -323,3 +345,32 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
         ):
             if "lora" in name or self.main_params_trainable:
                 yield name, param
+
+    @torch.inference_mode()
+    def encode(
+        self,
+        *args,
+        task: Union[str, None] = LORA_NO_UPDATE,
+        **kwargs,
+    ) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
+        """
+        Computes sentence embeddings
+
+        task(`str`, *optional*, defaults to `LORA_NO_UPDATE`):
+            Specifies the task for which the encoding is intended. This parameter controls the
+            use of specialized LoRA adapters that are tuned for specific tasks. If `task` is set
+            to `LORA_NO_UPDATE`, there will be no update to the current task, retaining the
+            existing adapter configuration. If `task` is explicitly set to `None`, all LoRA
+            adapters are disabled, and the model reverts to its original, general-purpose weights.
+            If `task` is set to a specific LoRA adaptation, that adaptation is activated.
+        """
+        if task != LORA_NO_UPDATE:
+            if not task:
+                warnings.warn(
+                    f"Task-specific embeddings are disabled. To enable, specify the `task` "
+                    f"argument with one of the supported tasks: {', '.join(self.config.lora_adaptations)}",
+                    category=UserWarning,
+                )
+            self.current_task = task
+
+        return super().encode(*args, **kwargs)
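For callers, the refactor changes adapter selection from integer indices to task names, starts with LoRA disabled, and makes `forward`/`encode` touch the selection only when `task` is passed explicitly; the `LORA_NO_UPDATE` sentinel means "leave the current adapter as is". The standalone sketch below mimics that control flow with a toy class so the semantics can be run in isolation; `AdapterSwitcher` and the task names are illustrative and do not exist in this repository.

LORA_NO_UPDATE = '__lora_no_update__'  # sentinel: "argument not given", distinct from an explicit None


class AdapterSwitcher:
    """Toy stand-in for XLMRobertaLoRA's task-selection logic."""

    def __init__(self, adaptations):
        self._adaptation_map = {name: idx for idx, name in enumerate(adaptations)}
        self.current_task = None  # LoRA disabled by default, as in the refactored __init__

    def encode(self, texts, task=LORA_NO_UPDATE):
        if task != LORA_NO_UPDATE:  # only update the selection when a task is passed
            if task is not None and task not in self._adaptation_map:
                raise ValueError(f"Unsupported task '{task}'")
            self.current_task = task
        return f"encoded {len(texts)} texts with task={self.current_task}"


switcher = AdapterSwitcher(["retrieval", "classification"])
print(switcher.encode(["hello"], task="retrieval"))  # activates the named adapter
print(switcher.encode(["hello"]))                    # keeps the previous selection
print(switcher.encode(["hello"], task=None))         # explicitly disables LoRA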
modeling_xlm_roberta.py CHANGED
@@ -1253,4 +1253,4 @@ class XLMRobertaForSequenceClassification(XLMRobertaPreTrainedModel):
             logits=logits,
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
-        )
+        )