Jackmin108 committed
Commit b490388
1 Parent(s): 7c4a80c

feat: support alibi


Signed-off-by: Meow <[email protected]>

Files changed (2)
  1. embedding.py +1 -1
  2. modeling_xlm_roberta.py +2 -1
embedding.py CHANGED
@@ -50,7 +50,7 @@ class XLMRobertaEmbeddings(nn.Module):
         embeddings = self.word_embeddings(input_ids)
         if self.max_position_embeddings > 0:
             if position_ids is None:
-                position_ids =create_position_ids_from_input_ids(input_ids, padding_idx=self.word_embeddings.padding_idx).to(input_ids.device)
+                position_ids = create_position_ids_from_input_ids(input_ids, padding_idx=self.word_embeddings.padding_idx).to(input_ids.device)
                 # position_ids = torch.arange(seqlen, dtype=torch.long, device=input_ids.device)
             position_embeddings = self.position_embeddings(position_ids)
             embeddings = embeddings + position_embeddings
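
The learned absolute position embeddings are only applied when max_position_embeddings > 0, so the change in modeling_xlm_roberta.py below (which passes -1 unless position_embedding_type is 'absolute') turns this whole branch into a no-op in ALiBi mode. For reference, a minimal sketch of what create_position_ids_from_input_ids computes, modeled on the Hugging Face RoBERTa helper of the same name; the repository's own version may differ in signature or details:

    import torch

    def create_position_ids_from_input_ids(input_ids, padding_idx):
        # Non-padding tokens get positions padding_idx + 1, padding_idx + 2, ...
        # while padding tokens keep position padding_idx (fairseq/RoBERTa convention).
        mask = input_ids.ne(padding_idx).int()
        incremental_indices = torch.cumsum(mask, dim=1).type_as(mask) * mask
        return incremental_indices.long() + padding_idx

    # Example: with padding_idx=1, [[0, 5, 7, 1, 1]] -> [[2, 3, 4, 1, 1]]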
modeling_xlm_roberta.py CHANGED
@@ -109,6 +109,7 @@ def create_mixer_cls(config, cross_attn=False, return_residual=False):
         fused_bias_fc=fused_bias_fc,
         use_flash_attn=use_flash_attn,
         return_residual=return_residual,
+        use_alibi=config.position_embedding_type == 'alibi',
         **rotary_kwargs,
     )
     return mixer_cls
@@ -429,7 +430,7 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
         self.embeddings = XLMRobertaEmbeddings(
             config.hidden_size,
             config.vocab_size,
-            config.max_position_embeddings,
+            config.max_position_embeddings if config.position_embedding_type == 'absolute' else -1,
             config.type_vocab_size,
             padding_idx=config.pad_token_id,
         )
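
The new use_alibi flag is computed from config.position_embedding_type == 'alibi' and forwarded to the attention mixer, which is expected to add a per-head linear bias to the attention scores instead of relying on learned absolute positions. A rough, illustrative sketch of how such a bias is usually constructed, following the ALiBi paper's geometric head slopes; the function names and the symmetric (bidirectional) distance used here are assumptions for illustration, not this repository's exact implementation:

    import torch

    def alibi_slopes(num_heads):
        # One slope per head: a geometric sequence 2**(-8/n), 2**(-16/n), ...
        # (Press et al., ALiBi). Real implementations also interpolate when
        # num_heads is not a power of two.
        start = 2 ** (-8.0 / num_heads)
        return torch.tensor([start ** (i + 1) for i in range(num_heads)])

    def alibi_bias(num_heads, seqlen):
        # Bias of shape (num_heads, seqlen, seqlen): -slope * |i - j|, added to
        # the raw attention scores before softmax (symmetric variant for encoders).
        slopes = alibi_slopes(num_heads)                   # (H,)
        pos = torch.arange(seqlen)
        dist = (pos[None, :] - pos[:, None]).abs()         # (S, S)
        return -slopes[:, None, None] * dist[None, :, :]   # (H, S, S)

    # scores: raw attention logits of shape (batch, heads, seqlen, seqlen)
    # attn = torch.softmax(scores + alibi_bias(num_heads, seqlen), dim=-1)

Because the companion change passes -1 for max_position_embeddings whenever the position embedding type is not 'absolute', the XLMRobertaEmbeddings branch shown above is skipped and positional information comes solely from this attention-side bias.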