Guanzheng committed on
Commit
1342086
1 Parent(s): efe2fd9

Update modeling_phi2_clex.py

Browse files
Files changed (1) hide show
  1. modeling_phi2_clex.py +4 -5
modeling_phi2_clex.py CHANGED
@@ -59,7 +59,10 @@ logger = logging.get_logger(__name__)
59
  _CHECKPOINT_FOR_DOC = "microsoft/phi-2"
60
  _CONFIG_FOR_DOC = "CLEXPhiConfig"
61
 
62
-
 
 
 
63
 
64
 
65
  # Copied from transformers.models.llama.modeling_llama._get_unpad_data
@@ -373,10 +376,6 @@ class PhiAttention(nn.Module):
373
  # [batch_size, seq_length, num_heads, head_dim]
374
  query_states = torch.cat((query_rot, query_pass), dim=-1)
375
  key_states = torch.cat((key_rot, key_pass), dim=-1)
376
- rotary_dim = int(self.partial_rotary_factor * self.head_dim)
377
- if past_key_value is not None:
378
- cache_kwargs = {"sin": sin, "cos": cos, "partial_rotation_size": rotary_dim}
379
- key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
380
 
381
  key_states = repeat_kv(key_states, self.num_key_value_groups)
382
  value_states = repeat_kv(value_states, self.num_key_value_groups)
 
59
  _CHECKPOINT_FOR_DOC = "microsoft/phi-2"
60
  _CONFIG_FOR_DOC = "CLEXPhiConfig"
61
 
62
+ PHI_PRETRAINED_MODEL_ARCHIVE_LIST = [
63
+ "microsoft/phi-2",
64
+ # See all Phi models at https://huggingface.co/models?filter=phi
65
+ ]
66
 
67
 
68
  # Copied from transformers.models.llama.modeling_llama._get_unpad_data
 
376
  # [batch_size, seq_length, num_heads, head_dim]
377
  query_states = torch.cat((query_rot, query_pass), dim=-1)
378
  key_states = torch.cat((key_rot, key_pass), dim=-1)
 
 
 
 
379
 
380
  key_states = repeat_kv(key_states, self.num_key_value_groups)
381
  value_states = repeat_kv(value_states, self.num_key_value_groups)