Update modeling_baichuan.py

#12
by JaheimLee - opened
Files changed (1)
  1. modeling_baichuan.py +10 -7
modeling_baichuan.py CHANGED
@@ -30,7 +30,8 @@ except ImportError:
     logger.warning(
         "Xformers is not installed correctly. If you want to use memory_efficient_attention to accelerate training use the following command to install Xformers\npip install xformers."
     )
-
+
+pytorch_major_version = int(torch.__version__.split('.')[0])
 
 def _get_interleave(n):
     def _get_interleave_power_of_2(n):
@@ -173,12 +174,14 @@ class BaichuanAttention(torch.nn.Module):
         past_key_value = (key_states, value_states) if use_cache else None
         if xops is not None and self.training:
             attn_weights = None
-            # query_states = query_states.transpose(1, 2)
-            # key_states = key_states.transpose(1, 2)
-            # value_states = value_states.transpose(1, 2)
-            # attn_output = xops.memory_efficient_attention(
-            #     query_states, key_states, value_states, attn_bias=attention_mask
-            # )
+            query_states = query_states.transpose(1, 2)
+            key_states = key_states.transpose(1, 2)
+            value_states = value_states.transpose(1, 2)
+            attn_output = xops.memory_efficient_attention(
+                query_states, key_states, value_states, attn_bias=attention_mask
+            )
+        elif pytorch_major_version >= 2:
+            attn_weights = None
             with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=True):
                 attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask = attention_mask)
             attn_output = attn_output.transpose(1, 2)
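
For reference, the backend selection this patch produces reads roughly like the standalone sketch below. It is a minimal illustration, not the model file itself: the function name attention_forward and the explicit training flag are assumptions, in the real file this logic lives inside BaichuanAttention.forward, and the final case falls through to the model's existing non-fused attention path (omitted here).

import torch
import torch.nn.functional as F

try:
    from xformers import ops as xops
except ImportError:
    xops = None

pytorch_major_version = int(torch.__version__.split('.')[0])


def attention_forward(query_states, key_states, value_states, attention_mask=None, training=True):
    # Inputs are (batch, num_heads, seq_len, head_dim); both fused branches
    # return (batch, seq_len, num_heads, head_dim), ready to be reshaped to
    # (batch, seq_len, hidden_size).
    if xops is not None and training:
        # xformers memory_efficient_attention expects (batch, seq_len, num_heads, head_dim),
        # hence the transposes restored by this patch.
        q = query_states.transpose(1, 2)
        k = key_states.transpose(1, 2)
        v = value_states.transpose(1, 2)
        return xops.memory_efficient_attention(q, k, v, attn_bias=attention_mask)
    elif pytorch_major_version >= 2:
        # PyTorch >= 2 fallback: fused scaled_dot_product_attention,
        # then transpose back to (batch, seq_len, num_heads, head_dim).
        with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=True):
            attn_output = F.scaled_dot_product_attention(
                query_states, key_states, value_states, attn_mask=attention_mask
            )
        return attn_output.transpose(1, 2)
    # Placeholder: the model file handles this case with its existing
    # (non-fused) attention implementation.
    raise NotImplementedError("Requires xformers (training) or PyTorch >= 2.")

With this ordering, xformers is preferred during training when it is installed, and pytorch_major_version (computed once at import time) gates the F.scaled_dot_product_attention branch so the fused fallback is only taken on PyTorch 2.x.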