jupyterjazz Jackmin108 commited on
Commit
0f0bed6
1 Parent(s): 4434bf3

fix-adapter-masks (#32)

Browse files

- fix: adapter masks (934939f54211c85cc0a5f9891937c4015377c102)


Co-authored-by: Jack Min Ong <[email protected]>

Files changed (4) hide show
  1. block.py +1 -1
  2. mha.py +9 -9
  3. mlp.py +9 -9
  4. modeling_xlm_roberta.py +1 -1
block.py CHANGED
@@ -233,7 +233,7 @@ class Block(nn.Module):
233
  is_rms_norm=isinstance(self.norm1, RMSNorm),
234
  )
235
  if not isinstance(self.mlp, nn.Identity):
236
- mlp_out = self.mlp(hidden_states, cu_adapter_mask=mixer_kwargs.get('cu_adapter_mask'))
237
  if self.return_residual: # mlp out is actually a pair here
238
  mlp_out, hidden_states = mlp_out
239
  if not self.fused_dropout_add_ln:
 
233
  is_rms_norm=isinstance(self.norm1, RMSNorm),
234
  )
235
  if not isinstance(self.mlp, nn.Identity):
236
+ mlp_out = self.mlp(hidden_states, adapter_mask=mixer_kwargs.get('adapter_mask'))
237
  if self.return_residual: # mlp out is actually a pair here
238
  mlp_out, hidden_states = mlp_out
239
  if not self.fused_dropout_add_ln:
mha.py CHANGED
@@ -590,7 +590,7 @@ class MHA(nn.Module):
590
  max_seqlen=None,
591
  mixer_subset=None,
592
  inference_params=None,
593
- cu_adapter_mask=None,
594
  **kwargs,
595
  ):
596
  """
@@ -647,13 +647,13 @@ class MHA(nn.Module):
647
  if not self.cross_attn and self.num_heads_kv == self.num_heads:
648
  assert x_kv is None and mixer_subset is None
649
 
650
- if cu_adapter_mask is not None:
651
- unique_tasks = torch.unique(cu_adapter_mask)
652
  qkv_dtype = next(self.Wqkv.parameters()).dtype
653
- qkv = torch.empty(x.shape[0], self.Wqkv.out_features,
654
  dtype=qkv_dtype, device=x.device)
655
  for task_id in unique_tasks:
656
- task_indices = (cu_adapter_mask == task_id).nonzero(as_tuple=True)[0]
657
  task_tensor = x[task_indices]
658
  if not self.return_residual:
659
  task_qkv = self.Wqkv(task_tensor, task_id=task_id)
@@ -755,13 +755,13 @@ class MHA(nn.Module):
755
  context = self._apply_rotary_update_kvcache_attention(q, kv, inference_params)
756
 
757
  inp = rearrange(context, "... h d -> ... (h d)")
758
- if cu_adapter_mask is not None:
759
- unique_tasks = torch.unique(cu_adapter_mask)
760
  out_dtype = next(self.out_proj.parameters()).dtype
761
- out = torch.empty(inp.shape[0], self.out_proj.out_features,
762
  dtype=out_dtype, device=inp.device)
763
  for task_id in unique_tasks:
764
- task_indices = (cu_adapter_mask == task_id).nonzero(as_tuple=True)[0]
765
  task_tensor = inp[task_indices]
766
  task_out = self.out_proj(task_tensor, task_id=task_id)
767
  out[task_indices] = task_out
 
590
  max_seqlen=None,
591
  mixer_subset=None,
592
  inference_params=None,
593
+ adapter_mask=None,
594
  **kwargs,
595
  ):
596
  """
 
647
  if not self.cross_attn and self.num_heads_kv == self.num_heads:
648
  assert x_kv is None and mixer_subset is None
649
 
650
+ if adapter_mask is not None:
651
+ unique_tasks = torch.unique(adapter_mask)
652
  qkv_dtype = next(self.Wqkv.parameters()).dtype
653
+ qkv = torch.empty(*x.shape[:-1], self.Wqkv.out_features,
654
  dtype=qkv_dtype, device=x.device)
655
  for task_id in unique_tasks:
656
+ task_indices = (adapter_mask == task_id).nonzero(as_tuple=True)[0]
657
  task_tensor = x[task_indices]
658
  if not self.return_residual:
659
  task_qkv = self.Wqkv(task_tensor, task_id=task_id)
 
755
  context = self._apply_rotary_update_kvcache_attention(q, kv, inference_params)
756
 
757
  inp = rearrange(context, "... h d -> ... (h d)")
758
+ if adapter_mask is not None:
759
+ unique_tasks = torch.unique(adapter_mask)
760
  out_dtype = next(self.out_proj.parameters()).dtype
761
+ out = torch.empty(*inp.shape[:-1], self.out_proj.out_features,
762
  dtype=out_dtype, device=inp.device)
763
  for task_id in unique_tasks:
764
+ task_indices = (adapter_mask == task_id).nonzero(as_tuple=True)[0]
765
  task_tensor = inp[task_indices]
766
  task_out = self.out_proj(task_tensor, task_id=task_id)
767
  out[task_indices] = task_out
mlp.py CHANGED
@@ -47,14 +47,14 @@ class Mlp(nn.Module):
47
  self.activation = activation
48
  self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs)
49
 
50
- def forward(self, x, cu_adapter_mask=None):
51
- if cu_adapter_mask is not None:
52
- unique_tasks = torch.unique(cu_adapter_mask)
53
  fc1_dtype = next(self.fc1.parameters()).dtype
54
- y = torch.empty(x.shape[0], self.fc1.out_features,
55
  dtype=fc1_dtype, device=x.device)
56
  for task_id in unique_tasks:
57
- task_indices = (cu_adapter_mask == task_id).nonzero(as_tuple=True)[0]
58
  task_tensor = x[task_indices]
59
  task_y = self.fc1(task_tensor, task_id=task_id)
60
  y[task_indices] = task_y
@@ -63,13 +63,13 @@ class Mlp(nn.Module):
63
 
64
  y = self.activation(y)
65
 
66
- if cu_adapter_mask is not None:
67
- unique_tasks = torch.unique(cu_adapter_mask)
68
  fc2_dtype = next(self.fc2.parameters()).dtype
69
- out = torch.empty(y.shape[0], self.fc2.out_features,
70
  dtype=fc2_dtype, device=y.device)
71
  for task_id in unique_tasks:
72
- task_indices = (cu_adapter_mask == task_id).nonzero(as_tuple=True)[0]
73
  task_tensor = y[task_indices]
74
  task_out = self.fc2(task_tensor, task_id=task_id)
75
  out[task_indices] = task_out
 
47
  self.activation = activation
48
  self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs)
49
 
50
+ def forward(self, x, adapter_mask=None):
51
+ if adapter_mask is not None:
52
+ unique_tasks = torch.unique(adapter_mask)
53
  fc1_dtype = next(self.fc1.parameters()).dtype
54
+ y = torch.empty(*x.shape[:-1], self.fc1.out_features,
55
  dtype=fc1_dtype, device=x.device)
56
  for task_id in unique_tasks:
57
+ task_indices = (adapter_mask == task_id).nonzero(as_tuple=True)[0]
58
  task_tensor = x[task_indices]
59
  task_y = self.fc1(task_tensor, task_id=task_id)
60
  y[task_indices] = task_y
 
63
 
64
  y = self.activation(y)
65
 
66
+ if adapter_mask is not None:
67
+ unique_tasks = torch.unique(adapter_mask)
68
  fc2_dtype = next(self.fc2.parameters()).dtype
69
+ out = torch.empty(*y.shape[:-1], self.fc2.out_features,
70
  dtype=fc2_dtype, device=y.device)
71
  for task_id in unique_tasks:
72
+ task_indices = (adapter_mask == task_id).nonzero(as_tuple=True)[0]
73
  task_tensor = y[task_indices]
74
  task_out = self.fc2(task_tensor, task_id=task_id)
75
  out[task_indices] = task_out
modeling_xlm_roberta.py CHANGED
@@ -230,7 +230,7 @@ class XLMRobertaEncoder(nn.Module):
230
  hidden_states, indices, cu_seqlens, max_seqlen_in_batch, cu_adapter_mask = unpad_input(
231
  hidden_states, key_padding_mask, adapter_mask
232
  )
233
- mixer_kwargs = {"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen_in_batch, "cu_adapter_mask": cu_adapter_mask}
234
 
235
  if subset_mask is None:
236
  for layer in self.layers:
 
230
  hidden_states, indices, cu_seqlens, max_seqlen_in_batch, cu_adapter_mask = unpad_input(
231
  hidden_states, key_padding_mask, adapter_mask
232
  )
233
+ mixer_kwargs = {"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen_in_batch, "adapter_mask": cu_adapter_mask}
234
 
235
  if subset_mask is None:
236
  for layer in self.layers: