walebadr committed on
Commit 00d9481
1 Parent(s): 9e4e0fc

Create modeling_decilm.py

Files changed (1)
  1. modeling_decilm.py +316 -0
modeling_decilm.py ADDED
@@ -0,0 +1,316 @@
# coding=utf-8
# Copyright and license in the repo.
""" PyTorch DeciLM model."""
from .version_check import check_transformers_version

check_transformers_version()

from typing import List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging

from .configuration_decilm import DeciLMConfig
from .transformers_v4_35_2__modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
from .transformers_v4_35_2__modeling_llama import LlamaMLP, LlamaRMSNorm, LlamaAttention, apply_rotary_pos_emb, \
    repeat_kv, LlamaPreTrainedModel, LLAMA_START_DOCSTRING, LlamaDecoderLayer, LlamaForCausalLM, LlamaModel, \
    BaseModelOutputWithPast, LLAMA_INPUTS_DOCSTRING

MODEL_FOR_CAUSAL_LM_MAPPING_NAMES["deci"] = "DeciLMForCausalLM"
_CONFIG_FOR_DOC = "DeciLMConfig"
logger = logging.get_logger(__name__)
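
# Registering "deci" in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES lets the transformers auto classes resolve
# this model type to DeciLMForCausalLM (defined below) when the checkpoint is loaded with
# trust_remote_code=True.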


class DeciLMAttention(LlamaAttention):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: DeciLMConfig, layer_idx: int):
        nn.Module.__init__(self)
        self.config = config
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.layer_idx = layer_idx
        self.num_key_value_heads = config.num_key_value_heads_per_layer[layer_idx]
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.pretraining_tp = config.pretraining_tp
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = getattr(config, 'rope_theta', None)

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

        self._init_rope()

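    # Unlike LlamaAttention, the number of key/value heads can differ per layer
    # (config.num_key_value_heads_per_layer). Illustrative arithmetic, with hypothetical numbers not
    # taken from any particular checkpoint: num_attention_heads = 32 and 4 KV heads for this layer give
    # num_key_value_groups = 32 // 4 = 8, so k_proj/v_proj emit 4 * head_dim features and each KV head
    # is shared by 8 query heads via repeat_kv() in forward() below.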
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        bsz, q_len, _ = hidden_states.size()
        is_decode = past_key_value is not None
        if self.pretraining_tp > 1:
            key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.pretraining_tp
            query_slices = self.q_proj.weight.split((self.num_heads * self.head_dim) // self.pretraining_tp, dim=0)
            key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
            value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)

            query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.pretraining_tp)]
            query_states = torch.cat(query_states, dim=-1)

            key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.pretraining_tp)]
            key_states = torch.cat(key_states, dim=-1)

            value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.pretraining_tp)]
            value_states = torch.cat(value_states, dim=-1)

        else:
            query_states = self.q_proj(hidden_states)
            key_states = self.k_proj(hidden_states)
            value_states = self.v_proj(hidden_states)

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            kv_seq_len += past_key_value[0].shape[-2]
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)

        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

        if past_key_value is not None:
            # reuse k, v, self_attention
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)

        past_key_value = (key_states, value_states) if use_cache else None

        # repeat k/v heads if n_kv_heads < n_heads
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)
        if is_decode:
            with torch.backends.cuda.sdp_kernel(enable_math=True, enable_flash=True,
                                                enable_mem_efficient=attention_mask is None):
                attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states,
                                                             is_causal=False,
                                                             attn_mask=attention_mask)
            attn_output = attn_output.contiguous().view(bsz, q_len, self.hidden_size)

        else:
            with torch.backends.cuda.sdp_kernel(enable_math=True, enable_flash=False, enable_mem_efficient=False):
                attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states,
                                                             is_causal=attention_mask is None,
                                                             attn_mask=attention_mask)

            if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
                raise ValueError(
                    f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                    f" {attn_output.size()}"
                )

            attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, q_len, self.hidden_size)

        if self.pretraining_tp > 1:
            attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2)
            o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.pretraining_tp, dim=1)
            attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.pretraining_tp)])
        else:
            attn_output = self.o_proj(attn_output)

        attn_weights = None

        return attn_output, attn_weights, past_key_value

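# Note on DeciLMAttention.forward above: during decoding (past_key_value is not None, so q_len == 1)
# the flash / memory-efficient SDPA kernels are enabled and causality comes from attending only to the
# cached keys (is_causal=False); during prefill the math kernel is used and is_causal=True is passed
# only when no attention mask was supplied.
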

class DeciLMDecoderLayer(LlamaDecoderLayer):
    def __init__(self, config: DeciLMConfig, layer_idx: int):
        nn.Module.__init__(self)
        self.hidden_size = config.hidden_size
        self.layer_idx = layer_idx
        self.self_attn = DeciLMAttention(config=config, layer_idx=layer_idx)
        self.mlp = LlamaMLP(config)
        self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

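# DeciLMDecoderLayer only overrides __init__ so that each layer builds a DeciLMAttention with that
# layer's key/value head count; the forward pass is inherited unchanged from LlamaDecoderLayer.
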

@add_start_docstrings(
    "The bare DeciLM Model outputting raw hidden-states without any specific head on top.",
    LLAMA_START_DOCSTRING,
)
class DeciLMPreTrainedModel(LlamaPreTrainedModel):
    config_class = DeciLMConfig
    _no_split_modules = ["DeciLMDecoderLayer"]
    _keys_to_ignore_on_load_missing = ["self_attn.rotary_emb.inv_freq"]

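# "self_attn.rotary_emb.inv_freq" is listed as ignorable on load because the rotary inverse-frequency
# buffer is recomputed when the attention modules are constructed, so checkpoints that omit it still
# load without warnings.
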

@add_start_docstrings(
    "The bare DeciLM Model outputting raw hidden-states without any specific head on top.",
    LLAMA_START_DOCSTRING,
)
class DeciLMModel(LlamaModel, DeciLMPreTrainedModel):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`DeciLMDecoderLayer`]

    Args:
        config: DeciLMConfig
    """

    def __init__(self, config: DeciLMConfig):
        DeciLMPreTrainedModel.__init__(self, config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList([DeciLMDecoderLayer(config, layer_idx) for layer_idx
                                     in range(config.num_hidden_layers)])
        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            batch_size, seq_length = input_ids.shape[:2]
        elif inputs_embeds is not None:
            batch_size, seq_length = inputs_embeds.shape[:2]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        past_key_values_length = 0
        if past_key_values is not None:
            past_key_values_length = past_key_values[0][0].shape[2]

        if position_ids is None:
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            position_ids = torch.arange(
                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
            )
            position_ids = position_ids.unsqueeze(0)

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
        if attention_mask is not None:
            # 4d mask is passed through the layers
            attention_mask = _prepare_4d_causal_attention_mask(
                attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
            )

        # embed positions
        hidden_states = inputs_embeds

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = () if use_cache else None

        for idx, decoder_layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            past_key_value = past_key_values[idx] if past_key_values is not None else None

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    attention_mask,
                    position_ids,
                    past_key_value,
                    output_attentions,
                    use_cache,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_value,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None
        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )

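# In DeciLMModel.forward above, an attention_mask that contains no zeros (i.e. no padding) is dropped
# before the decoder layers run, so DeciLMAttention can take SDPA's is_causal fast path instead of
# consuming a materialized 4D causal mask.
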

class DeciLMForCausalLM(LlamaForCausalLM, DeciLMPreTrainedModel):
    def __init__(self, config):
        DeciLMPreTrainedModel.__init__(self, config)
        self.model = DeciLMModel(config)
        self.pretraining_tp = config.pretraining_tp
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()
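

# A minimal usage sketch, for illustration only. The checkpoint id "Deci/DeciLM-7B" is an assumption
# and is not dictated by this file; any repository that ships this modeling code should work the same
# way. Kept as comments so that importing this module has no side effects.
#
#   from transformers import AutoModelForCausalLM, AutoTokenizer
#
#   tokenizer = AutoTokenizer.from_pretrained("Deci/DeciLM-7B", trust_remote_code=True)
#   model = AutoModelForCausalLM.from_pretrained("Deci/DeciLM-7B", trust_remote_code=True)
#   inputs = tokenizer("DeciLM is a language model that", return_tensors="pt")
#   output_ids = model.generate(**inputs, max_new_tokens=20)
#   print(tokenizer.decode(output_ids[0], skip_special_tokens=True))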