Jackmin108 committed on
Commit 8f83a35 · 2 parents: 5f8e4b6, 57dbe22

merge recent changes


Signed-off-by: Meow <[email protected]>

configuration_xlm_roberta.py CHANGED
@@ -5,6 +5,9 @@ from transformers import PretrainedConfig
 
 
 class XLMRobertaFlashConfig(PretrainedConfig):
+
+    model_type = "xlm-roberta"
+
     def __init__(
         self,
         vocab_size: int = 250002,
@@ -25,6 +28,7 @@ class XLMRobertaFlashConfig(PretrainedConfig):
         position_embedding_type: str = "rotary",
         rotary_emb_base: float = 10000.0,
         use_cache: bool = True,
+        use_reentrant: bool = False,
         classifier_dropout: Optional[float] = None,
         lora_adaptations: Optional[List[str]] = None,
         task_instructions: Optional[Dict[str, str]] = None,
@@ -62,6 +66,7 @@ class XLMRobertaFlashConfig(PretrainedConfig):
         position_embedding_type (str): Type of position embeddings. Options are 'absolute', 'alibi', or 'rotary'.
         rotary_emb_base (float): Base for rotary embeddings.
         use_cache (bool): Whether or not the model should return the last key/values attentions (not used by all models).
+        use_reentrant (bool): Whether or not the model should enable the 'use_reentrant' flag in gradient checkpointing.
         classifier_dropout (Optional[float]): The dropout ratio for the classification head.
         lora_adaptations (Optional[List[str]]): LoRA adaptations configuration.
         lora_prompts (Optional[Dict[str, str]]): LoRA prompts configuration.
@@ -100,6 +105,7 @@ class XLMRobertaFlashConfig(PretrainedConfig):
         self.position_embedding_type = position_embedding_type
         self.rotary_emb_base = rotary_emb_base
         self.use_cache = use_cache
+        self.use_reentrant = use_reentrant
         self.classifier_dropout = classifier_dropout
         self.load_trained_adapters = load_trained_adapters
         self.lora_adaptations = lora_adaptations
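
Taken together, the configuration changes add a class-level model_type and a new use_reentrant flag stored on the config next to use_cache. A minimal sketch of constructing the patched config directly, assuming the file is imported as a local module (as in a trust_remote_code repository):

from configuration_xlm_roberta import XLMRobertaFlashConfig  # local module from this repo

# use_reentrant is new in this commit; False selects PyTorch's non-reentrant
# gradient-checkpointing variant, True keeps the legacy reentrant behaviour.
config = XLMRobertaFlashConfig(
    vocab_size=250002,
    use_cache=True,
    use_reentrant=False,
)
print(config.model_type)     # "xlm-roberta", now set at the class level
print(config.use_reentrant)  # False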
mha.py CHANGED
@@ -463,6 +463,7 @@ class MHA(nn.Module):
                 scale_base=rotary_emb_scale_base,
                 interleaved=rotary_emb_interleaved,
                 device=device,
+                use_flash_attn=use_flash_attn,
             )
 
         if fused_bias_fc and FusedDense is None:
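
For context, this one-line change forwards the attention-backend choice into the rotary module at construction time. A sketch of the full call site, assuming the surrounding argument names match the upstream flash-attention MHA layer (only the lines in the hunk are taken from this repo):

# Sketch of MHA.__init__ (names outside the hunk are assumptions):
self.rotary_emb = RotaryEmbedding(
    self.rotary_emb_dim,
    base=rotary_emb_base,
    scale_base=rotary_emb_scale_base,
    interleaved=rotary_emb_interleaved,
    device=device,
    use_flash_attn=use_flash_attn,  # new: decided once here instead of probed at call time
)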
modeling_lora.py CHANGED
@@ -11,11 +11,9 @@ from torch.nn import Parameter
 from torch.nn import functional as F
 from transformers import PretrainedConfig
 
-from .modeling_xlm_roberta import (
-    XLMRobertaFlashConfig,
-    XLMRobertaModel,
-    XLMRobertaPreTrainedModel,
-)
+from .rotary import RotaryEmbedding
+from .modeling_xlm_roberta import (XLMRobertaFlashConfig, XLMRobertaModel,
+                                   XLMRobertaPreTrainedModel)
 
 
 def initialized_weights(
@@ -328,16 +326,13 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
         use_safetensors: bool = None,
         **kwargs,
     ):
-        config = XLMRobertaFlashConfig.from_pretrained(
-            pretrained_model_name_or_path, *model_args, **kwargs
-        )
        if config.load_trained_adapters:  # checkpoint already contains LoRA adapters
            return super().from_pretrained(
-                pretrained_model_name_or_path, *model_args, **kwargs
+                pretrained_model_name_or_path, *model_args, use_flash_attn=config.use_flash_attn, **kwargs
            )
        else:  # initializing new adapters
            roberta = XLMRobertaModel.from_pretrained(
-                pretrained_model_name_or_path, *model_args, **kwargs
+                pretrained_model_name_or_path, *model_args, use_flash_attn=config.use_flash_attn, **kwargs
            )
            return cls(config, roberta=roberta)
 
@@ -399,14 +394,6 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
         adapter_mask = torch.full(
             (num_examples,), task_id, dtype=torch.int32, device=self.device
         )
-        if task_type in ["query", "passage"]:
-            if isinstance(sentences, str):
-                sentences = self._task_instructions[task_type] + " " + sentences
-            else:
-                sentences = [
-                    self._task_instructions[task_type] + " " + sentence
-                    for sentence in sentences
-                ]
         return self.roberta.encode(
             sentences, *args, adapter_mask=adapter_mask, **kwargs
         )
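
A hedged sketch of how a caller might exercise the new path. It assumes from_pretrained receives an already-resolved config (the hunk reads config.load_trained_adapters without loading one, which implies config now comes from the method signature outside these hunks), and the checkpoint id is a placeholder:

from configuration_xlm_roberta import XLMRobertaFlashConfig  # local modules from this repo
from modeling_lora import XLMRobertaLoRA

# Placeholder checkpoint id; whatever use_flash_attn is set on the config is now
# forwarded to both branches (loading trained adapters or initializing new ones).
config = XLMRobertaFlashConfig.from_pretrained("org/xlm-roberta-lora-checkpoint")
config.use_flash_attn = False  # e.g. CPU-only inference

model = XLMRobertaLoRA.from_pretrained("org/xlm-roberta-lora-checkpoint", config=config)

Note that the block removed from encode also means task instructions are no longer prepended to the input sentences here; the wrapper now only attaches the per-task adapter mask.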
modeling_xlm_roberta.py CHANGED
@@ -30,6 +30,7 @@ from transformers.models.bert.modeling_bert import (
 from transformers.models.xlm_roberta.modeling_xlm_roberta import \
     XLMRobertaLMHead
 
+from .rotary import RotaryEmbedding
 from .block import Block
 from .configuration_xlm_roberta import XLMRobertaFlashConfig
 from .embedding import XLMRobertaEmbeddings
@@ -63,9 +64,7 @@ logger = logging.getLogger(__name__)
 
 
 def get_use_flash_attn(config: XLMRobertaFlashConfig):
-    if not getattr(config, "use_flash_attn", False):
-        return False
-    if not torch.cuda.is_available():
+    if not getattr(config, "use_flash_attn", False) or not torch.cuda.is_available():
         return False
     if importlib.util.find_spec("flash_attn") is None:
         logger.warning(
@@ -181,6 +180,7 @@ class XLMRobertaEncoder(nn.Module):
     def __init__(self, config: XLMRobertaFlashConfig):
         super().__init__()
         self.use_flash_attn = get_use_flash_attn(config)
+        self.use_reentrant = config.use_reentrant
         self.layers = nn.ModuleList(
             [create_block(config, layer_idx=i) for i in range(config.num_hidden_layers)]
         )
@@ -210,7 +210,7 @@ class XLMRobertaEncoder(nn.Module):
                 hidden_states = torch.utils.checkpoint.checkpoint(
                     layer,
                     hidden_states,
-                    use_reentrant=False,
+                    use_reentrant=self.use_reentrant,
                     mixer_kwargs=mixer_kwargs,
                 )
             else:
@@ -234,7 +234,7 @@ class XLMRobertaEncoder(nn.Module):
                 hidden_states = torch.utils.checkpoint.checkpoint(
                     layer,
                     hidden_states,
-                    use_reentrant=False,
+                    use_reentrant=self.use_reentrant,
                     mixer_kwargs=mixer_kwargs,
                 )
             else:
@@ -246,7 +246,7 @@ class XLMRobertaEncoder(nn.Module):
                 hidden_states = torch.utils.checkpoint.checkpoint(
                     layer,
                     hidden_states,
-                    use_reentrant=False,
+                    use_reentrant=self.use_reentrant,
                     mixer_kwargs=mixer_kwargs,
                 )
             else:
@@ -284,7 +284,7 @@ class XLMRobertaEncoder(nn.Module):
                 torch.utils.checkpoint.checkpoint(
                     self.layers[-1],
                     hidden_states_subset,
-                    use_reentrant=False,
+                    use_reentrant=self.use_reentrant,
                     mixer_kwargs=mixer_kwargs,
                 )
             else:
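
A minimal, self-contained sketch of what the plumbed-through flag controls, using only the public torch.utils.checkpoint API (the linear layer is a stand-in, not the repo's Block):

import torch
import torch.nn as nn

layer = nn.Linear(16, 16)                                # stand-in for a transformer block
hidden_states = torch.randn(2, 16, requires_grad=True)

# use_reentrant=False (the new config default) uses PyTorch's non-reentrant
# checkpointing: activations are recomputed during backward without the legacy
# reentrant autograd machinery; use_reentrant=True keeps the old variant.
out = torch.utils.checkpoint.checkpoint(layer, hidden_states, use_reentrant=False)
out.sum().backward()

Exposing the flag on the config lets a checkpoint keep whichever checkpointing variant it was trained with, without editing the modeling code.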
rotary.py CHANGED
@@ -4,7 +4,6 @@
 
 # Copyright (c) 2023, Tri Dao.
 
-import math
 from typing import Optional, Tuple, Union
 
 import torch
@@ -16,7 +15,10 @@ if torch.cuda.is_available():
     except ImportError:
 
         def apply_rotary(*args, **kwargs):
-            raise RuntimeError("RoPE requires flash-attention to be installed")
+            raise RuntimeError(
+                "FlashAttention is not installed. To proceed with training, please install FlashAttention. "
+                "For inference, you have two options: either install FlashAttention or disable it by setting use_flash_attn=False when loading the model."
+            )
 
 
 def rotate_half(x, interleaved=False):
@@ -169,12 +171,13 @@ class ApplyRotaryEmbQKV_(torch.autograd.Function):
         seqlen_offsets: Union[int, torch.Tensor] = 0,
         cu_seqlens: Optional[torch.Tensor] = None,
         max_seqlen: Optional[int] = None,
+        use_flash_attn: bool = True,
     ):
         # batch, seqlen, three, nheads, headdim = qkv.shape
         assert qkv.shape[-3] == 3
         if cos_k is None and sin_k is None and qkv.is_contiguous():
 
-            if torch.cuda.is_available():
+            if use_flash_attn:
                 # Call 1 kernel instead of 2 kernels
                 # We need qkv to be contiguous so that when we reshape to combine (3, nheads)
                 # dimensions, we get the same tensor
@@ -288,7 +291,7 @@ class ApplyRotaryEmbQKV_(torch.autograd.Function):
             cu_seqlens=cu_seqlens,
             max_seqlen=ctx.max_seqlen,
         )
-        return dqkv, None, None, None, None, None, None, None, None
+        return dqkv, None, None, None, None, None, None, None, None, None
 
 
 def apply_rotary_emb_qkv_(
@@ -301,6 +304,7 @@ def apply_rotary_emb_qkv_(
     seqlen_offsets: Union[int, torch.Tensor] = 0,
     cu_seqlens: Optional[torch.Tensor] = None,
    max_seqlen: Optional[int] = None,
+    use_flash_attn=True,
 ):
     """
     Arguments:
@@ -321,7 +325,7 @@ def apply_rotary_emb_qkv_(
     Apply rotary embedding *inplace* to the first rotary_dim of Q and K.
     """
     return ApplyRotaryEmbQKV_.apply(
-        qkv, cos, sin, cos_k, sin_k, interleaved, seqlen_offsets, cu_seqlens, max_seqlen
+        qkv, cos, sin, cos_k, sin_k, interleaved, seqlen_offsets, cu_seqlens, max_seqlen, use_flash_attn,
     )
 
 
@@ -443,6 +447,7 @@ class RotaryEmbedding(torch.nn.Module):
         scale_base=None,
         pos_idx_in_fp32=True,
         device=None,
+        use_flash_attn=True,
     ):
         """
         interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
@@ -462,6 +467,7 @@ class RotaryEmbedding(torch.nn.Module):
         self.dim = dim
         self._base = float(base)
         self.pos_idx_in_fp32 = pos_idx_in_fp32
+        self.use_flash_attn = use_flash_attn
         # Generate and save the inverse frequency buffer (non trainable)
         inv_freq = self._compute_inv_freq(device)
         self.register_buffer("inv_freq", inv_freq, persistent=False)
@@ -588,6 +594,7 @@ class RotaryEmbedding(torch.nn.Module):
                     seqlen_offsets=seqlen_offset,
                     cu_seqlens=cu_seqlens,
                     max_seqlen=max_seqlen,
+                    use_flash_attn=self.use_flash_attn,
                 )
             else:
                 return apply_rotary_emb_qkv_(
@@ -600,6 +607,7 @@ class RotaryEmbedding(torch.nn.Module):
                     seqlen_offsets=seqlen_offset,
                     cu_seqlens=cu_seqlens,
                     max_seqlen=max_seqlen,
+                    use_flash_attn=self.use_flash_attn,
                 )
        else:
            q = qkv
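
The net effect in rotary.py is that the fused flash-attention kernel is now gated by an explicit use_flash_attn flag instead of torch.cuda.is_available(), so the pure-PyTorch fallback can be selected deliberately. A hedged sketch of using the patched class on a machine without flash_attn, assuming the fallback branch (outside these hunks) accepts the default keyword arguments:

import torch
from rotary import RotaryEmbedding  # local module from this repo

# use_flash_attn=False is the new constructor argument; it is forwarded to
# apply_rotary_emb_qkv_ so the non-fused code path is taken even on GPU hosts.
rotary = RotaryEmbedding(32, use_flash_attn=False)

qkv = torch.randn(2, 8, 3, 4, 32)  # (batch, seqlen, 3, nheads, headdim)
qkv = rotary(qkv)                  # applies rotary embedding to Q and K in place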