ahatamiz commited on
Commit
54acb42
1 Parent(s): 166dfd4

Upload model

Browse files
config.json ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "MambaVisionModel"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "configuration_mambavision.MambaVisionConfig",
7
+ "AutoModel": "modeling_mambavision.MambaVisionModel"
8
+ },
9
+ "depths": [
10
+ 3,
11
+ 3,
12
+ 10,
13
+ 5
14
+ ],
15
+ "dim": 128,
16
+ "drop_path_rate": 0.3,
17
+ "in_dim": 64,
18
+ "layer_scale": 1e-05,
19
+ "layer_scale_conv": null,
20
+ "mlp_ratio": 4,
21
+ "model_type": "mambavision",
22
+ "num_heads": [
23
+ 2,
24
+ 4,
25
+ 8,
26
+ 16
27
+ ],
28
+ "torch_dtype": "float32",
29
+ "transformers_version": "4.36.2",
30
+ "window_size": [
31
+ 8,
32
+ 8,
33
+ 14,
34
+ 7
35
+ ]
36
+ }
configuration_mambavision.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig
2
+
3
+ class MambaVisionConfig(PretrainedConfig):
4
+ model_type = "mambavision"
5
+
6
+ def __init__(
7
+ self,
8
+ depths=[3, 3, 10, 5],
9
+ num_heads=[2, 4, 8, 16],
10
+ window_size=[8, 8, 14, 7],
11
+ dim=128,
12
+ in_dim=64,
13
+ mlp_ratio=4,
14
+ drop_path_rate=0.3,
15
+ layer_scale=1e-5,
16
+ layer_scale_conv=None,
17
+ **kwargs,
18
+ ):
19
+ self.depths = depths
20
+ self.num_heads = num_heads
21
+ self.window_size = window_size
22
+ self.dim = dim
23
+ self.in_dim = in_dim
24
+ self.mlp_ratio = mlp_ratio
25
+ self.drop_path_rate = drop_path_rate
26
+ self.layer_scale=layer_scale
27
+ self.layer_scale_conv=layer_scale_conv
28
+ super().__init__(**kwargs)
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8a6608c75914925868fbbe4e0386af8855bc9e6f4e68ccd58958bcb9f1dfcda9
3
+ size 390807656
modeling_mambavision.py ADDED
@@ -0,0 +1,764 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+
3
+ # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
6
+ # and proprietary rights in and to this software, related documentation
7
+ # and any modifications thereto. Any use, reproduction, disclosure or
8
+ # distribution of this software and related documentation without an express
9
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
10
+
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ from timm.models.registry import register_model
15
+ import math
16
+ from timm.models.layers import trunc_normal_, DropPath, LayerNorm2d
17
+ from timm.models._builder import resolve_pretrained_cfg
18
+ try:
19
+ from timm.models._builder import _update_default_kwargs as update_args
20
+ except:
21
+ from timm.models._builder import _update_default_model_kwargs as update_args
22
+ from timm.models.vision_transformer import Mlp, PatchEmbed
23
+ from timm.models.layers import DropPath, trunc_normal_
24
+ from timm.models.registry import register_model
25
+ import torch.nn.functional as F
26
+ from mamba_ssm.ops.selective_scan_interface import selective_scan_fn
27
+ from einops import rearrange, repeat
28
+
29
+ from transformers import PreTrainedModel
30
+
31
+ from configuration_mambavision import MambaVisionConfig
32
+
33
+
34
+ def _cfg(url='', **kwargs):
35
+ return {'url': url,
36
+ 'num_classes': 1000,
37
+ 'input_size': (3, 224, 224),
38
+ 'pool_size': None,
39
+ 'crop_pct': 0.875,
40
+ 'interpolation': 'bicubic',
41
+ 'fixed_input_size': True,
42
+ 'mean': (0.485, 0.456, 0.406),
43
+ 'std': (0.229, 0.224, 0.225),
44
+ **kwargs
45
+ }
46
+
47
+
48
+ default_cfgs = {
49
+ 'mamba_vision_T': _cfg(url='https://huggingface.co/nvidia/MambaVision-T-1K/resolve/main/mambavision_tiny_1k.pth.tar',
50
+ crop_pct=1.0,
51
+ input_size=(3, 224, 224),
52
+ crop_mode='center'),
53
+ 'mamba_vision_T2': _cfg(url='https://huggingface.co/nvidia/MambaVision-T2-1K/resolve/main/mambavision_tiny2_1k.pth.tar',
54
+ crop_pct=0.98,
55
+ input_size=(3, 224, 224),
56
+ crop_mode='center'),
57
+ 'mamba_vision_S': _cfg(url='https://huggingface.co/nvidia/MambaVision-S-1K/resolve/main/mambavision_small_1k.pth.tar',
58
+ crop_pct=0.93,
59
+ input_size=(3, 224, 224),
60
+ crop_mode='center'),
61
+ 'mamba_vision_B': _cfg(url='https://huggingface.co/nvidia/MambaVision-B-1K/resolve/main/mambavision_base_1k.pth.tar',
62
+ crop_pct=1.0,
63
+ input_size=(3, 224, 224),
64
+ crop_mode='center'),
65
+ 'mamba_vision_L': _cfg(url='https://huggingface.co/nvidia/MambaVision-L-1K/resolve/main/mambavision_large_1k.pth.tar',
66
+ crop_pct=1.0,
67
+ input_size=(3, 224, 224),
68
+ crop_mode='center'),
69
+ 'mamba_vision_L2': _cfg(url='https://huggingface.co/nvidia/MambaVision-L2-1K/resolve/main/mambavision_large2_1k.pth.tar',
70
+ crop_pct=1.0,
71
+ input_size=(3, 224, 224),
72
+ crop_mode='center')
73
+ }
74
+
75
+
76
+ def window_partition(x, window_size):
77
+ """
78
+ Args:
79
+ x: (B, C, H, W)
80
+ window_size: window size
81
+ h_w: Height of window
82
+ w_w: Width of window
83
+ Returns:
84
+ local window features (num_windows*B, window_size*window_size, C)
85
+ """
86
+ B, C, H, W = x.shape
87
+ x = x.view(B, C, H // window_size, window_size, W // window_size, window_size)
88
+ windows = x.permute(0, 2, 4, 3, 5, 1).reshape(-1, window_size*window_size, C)
89
+ return windows
90
+
91
+
92
+ def window_reverse(windows, window_size, H, W):
93
+ """
94
+ Args:
95
+ windows: local window features (num_windows*B, window_size, window_size, C)
96
+ window_size: Window size
97
+ H: Height of image
98
+ W: Width of image
99
+ Returns:
100
+ x: (B, C, H, W)
101
+ """
102
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
103
+ x = windows.reshape(B, H // window_size, W // window_size, window_size, window_size, -1)
104
+ x = x.permute(0, 5, 1, 3, 2, 4).reshape(B,windows.shape[2], H, W)
105
+ return x
106
+
107
+
108
+ def _load_state_dict(module, state_dict, strict=False, logger=None):
109
+ """Load state_dict to a module.
110
+
111
+ This method is modified from :meth:`torch.nn.Module.load_state_dict`.
112
+ Default value for ``strict`` is set to ``False`` and the message for
113
+ param mismatch will be shown even if strict is False.
114
+
115
+ Args:
116
+ module (Module): Module that receives the state_dict.
117
+ state_dict (OrderedDict): Weights.
118
+ strict (bool): whether to strictly enforce that the keys
119
+ in :attr:`state_dict` match the keys returned by this module's
120
+ :meth:`~torch.nn.Module.state_dict` function. Default: ``False``.
121
+ logger (:obj:`logging.Logger`, optional): Logger to log the error
122
+ message. If not specified, print function will be used.
123
+ """
124
+ unexpected_keys = []
125
+ all_missing_keys = []
126
+ err_msg = []
127
+
128
+ metadata = getattr(state_dict, '_metadata', None)
129
+ state_dict = state_dict.copy()
130
+ if metadata is not None:
131
+ state_dict._metadata = metadata
132
+
133
+ def load(module, prefix=''):
134
+ local_metadata = {} if metadata is None else metadata.get(
135
+ prefix[:-1], {})
136
+ module._load_from_state_dict(state_dict, prefix, local_metadata, True,
137
+ all_missing_keys, unexpected_keys,
138
+ err_msg)
139
+ for name, child in module._modules.items():
140
+ if child is not None:
141
+ load(child, prefix + name + '.')
142
+
143
+ load(module)
144
+ load = None
145
+ missing_keys = [
146
+ key for key in all_missing_keys if 'num_batches_tracked' not in key
147
+ ]
148
+
149
+ if unexpected_keys:
150
+ err_msg.append('unexpected key in source '
151
+ f'state_dict: {", ".join(unexpected_keys)}\n')
152
+ if missing_keys:
153
+ err_msg.append(
154
+ f'missing keys in source state_dict: {", ".join(missing_keys)}\n')
155
+
156
+
157
+ if len(err_msg) > 0:
158
+ err_msg.insert(
159
+ 0, 'The model and loaded state dict do not match exactly\n')
160
+ err_msg = '\n'.join(err_msg)
161
+ if strict:
162
+ raise RuntimeError(err_msg)
163
+ elif logger is not None:
164
+ logger.warning(err_msg)
165
+ else:
166
+ print(err_msg)
167
+
168
+
169
+ def _load_checkpoint(model,
170
+ filename,
171
+ map_location='cpu',
172
+ strict=False,
173
+ logger=None):
174
+ """Load checkpoint from a file or URI.
175
+
176
+ Args:
177
+ model (Module): Module to load checkpoint.
178
+ filename (str): Accept local filepath, URL, ``torchvision://xxx``,
179
+ ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
180
+ details.
181
+ map_location (str): Same as :func:`torch.load`.
182
+ strict (bool): Whether to allow different params for the model and
183
+ checkpoint.
184
+ logger (:mod:`logging.Logger` or None): The logger for error message.
185
+
186
+ Returns:
187
+ dict or OrderedDict: The loaded checkpoint.
188
+ """
189
+ checkpoint = torch.load(filename, map_location=map_location)
190
+ if not isinstance(checkpoint, dict):
191
+ raise RuntimeError(
192
+ f'No state_dict found in checkpoint file {filename}')
193
+ if 'state_dict' in checkpoint:
194
+ state_dict = checkpoint['state_dict']
195
+ elif 'model' in checkpoint:
196
+ state_dict = checkpoint['model']
197
+ else:
198
+ state_dict = checkpoint
199
+ if list(state_dict.keys())[0].startswith('module.'):
200
+ state_dict = {k[7:]: v for k, v in state_dict.items()}
201
+
202
+ if sorted(list(state_dict.keys()))[0].startswith('encoder'):
203
+ state_dict = {k.replace('encoder.', ''): v for k, v in state_dict.items() if k.startswith('encoder.')}
204
+
205
+ _load_state_dict(model, state_dict, strict, logger)
206
+ return checkpoint
207
+
208
+
209
+ class Downsample(nn.Module):
210
+ """
211
+ Down-sampling block"
212
+ """
213
+
214
+ def __init__(self,
215
+ dim,
216
+ keep_dim=False,
217
+ ):
218
+ """
219
+ Args:
220
+ dim: feature size dimension.
221
+ norm_layer: normalization layer.
222
+ keep_dim: bool argument for maintaining the resolution.
223
+ """
224
+
225
+ super().__init__()
226
+ if keep_dim:
227
+ dim_out = dim
228
+ else:
229
+ dim_out = 2 * dim
230
+ self.reduction = nn.Sequential(
231
+ nn.Conv2d(dim, dim_out, 3, 2, 1, bias=False),
232
+ )
233
+
234
+ def forward(self, x):
235
+ x = self.reduction(x)
236
+ return x
237
+
238
+
239
+ class PatchEmbed(nn.Module):
240
+ """
241
+ Patch embedding block"
242
+ """
243
+
244
+ def __init__(self, in_chans=3, in_dim=64, dim=96):
245
+ """
246
+ Args:
247
+ in_chans: number of input channels.
248
+ dim: feature size dimension.
249
+ """
250
+ # in_dim = 1
251
+ super().__init__()
252
+ self.proj = nn.Identity()
253
+ self.conv_down = nn.Sequential(
254
+ nn.Conv2d(in_chans, in_dim, 3, 2, 1, bias=False),
255
+ nn.BatchNorm2d(in_dim, eps=1e-4),
256
+ nn.ReLU(),
257
+ nn.Conv2d(in_dim, dim, 3, 2, 1, bias=False),
258
+ nn.BatchNorm2d(dim, eps=1e-4),
259
+ nn.ReLU()
260
+ )
261
+
262
+ def forward(self, x):
263
+ x = self.proj(x)
264
+ x = self.conv_down(x)
265
+ return x
266
+
267
+
268
+ class ConvBlock(nn.Module):
269
+
270
+ def __init__(self, dim,
271
+ drop_path=0.,
272
+ layer_scale=None,
273
+ kernel_size=3):
274
+ super().__init__()
275
+
276
+ self.conv1 = nn.Conv2d(dim, dim, kernel_size=kernel_size, stride=1, padding=1)
277
+ self.norm1 = nn.BatchNorm2d(dim, eps=1e-5)
278
+ self.act1 = nn.GELU(approximate= 'tanh')
279
+ self.conv2 = nn.Conv2d(dim, dim, kernel_size=kernel_size, stride=1, padding=1)
280
+ self.norm2 = nn.BatchNorm2d(dim, eps=1e-5)
281
+ self.layer_scale = layer_scale
282
+ if layer_scale is not None and type(layer_scale) in [int, float]:
283
+ self.gamma = nn.Parameter(layer_scale * torch.ones(dim))
284
+ self.layer_scale = True
285
+ else:
286
+ self.layer_scale = False
287
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
288
+
289
+ def forward(self, x):
290
+ input = x
291
+ x = self.conv1(x)
292
+ x = self.norm1(x)
293
+ x = self.act1(x)
294
+ x = self.conv2(x)
295
+ x = self.norm2(x)
296
+ if self.layer_scale:
297
+ x = x * self.gamma.view(1, -1, 1, 1)
298
+ x = input + self.drop_path(x)
299
+ return x
300
+
301
+
302
+ class MambaVisionMixer(nn.Module):
303
+ def __init__(
304
+ self,
305
+ d_model,
306
+ d_state=16,
307
+ d_conv=4,
308
+ expand=2,
309
+ dt_rank="auto",
310
+ dt_min=0.001,
311
+ dt_max=0.1,
312
+ dt_init="random",
313
+ dt_scale=1.0,
314
+ dt_init_floor=1e-4,
315
+ conv_bias=True,
316
+ bias=False,
317
+ use_fast_path=True,
318
+ layer_idx=None,
319
+ device=None,
320
+ dtype=None,
321
+ ):
322
+ factory_kwargs = {"device": device, "dtype": dtype}
323
+ super().__init__()
324
+ self.d_model = d_model
325
+ self.d_state = d_state
326
+ self.d_conv = d_conv
327
+ self.expand = expand
328
+ self.d_inner = int(self.expand * self.d_model)
329
+ self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank
330
+ self.use_fast_path = use_fast_path
331
+ self.layer_idx = layer_idx
332
+ self.in_proj = nn.Linear(self.d_model, self.d_inner, bias=bias, **factory_kwargs)
333
+ self.x_proj = nn.Linear(
334
+ self.d_inner//2, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs
335
+ )
336
+ self.dt_proj = nn.Linear(self.dt_rank, self.d_inner//2, bias=True, **factory_kwargs)
337
+ dt_init_std = self.dt_rank**-0.5 * dt_scale
338
+ if dt_init == "constant":
339
+ nn.init.constant_(self.dt_proj.weight, dt_init_std)
340
+ elif dt_init == "random":
341
+ nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std)
342
+ else:
343
+ raise NotImplementedError
344
+ dt = torch.exp(
345
+ torch.rand(self.d_inner//2, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
346
+ + math.log(dt_min)
347
+ ).clamp(min=dt_init_floor)
348
+ inv_dt = dt + torch.log(-torch.expm1(-dt))
349
+ with torch.no_grad():
350
+ self.dt_proj.bias.copy_(inv_dt)
351
+ self.dt_proj.bias._no_reinit = True
352
+ A = repeat(
353
+ torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device),
354
+ "n -> d n",
355
+ d=self.d_inner//2,
356
+ ).contiguous()
357
+ A_log = torch.log(A)
358
+ self.A_log = nn.Parameter(A_log)
359
+ self.A_log._no_weight_decay = True
360
+ self.D = nn.Parameter(torch.ones(self.d_inner//2, device=device))
361
+ self.D._no_weight_decay = True
362
+ self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)
363
+ self.conv1d_x = nn.Conv1d(
364
+ in_channels=self.d_inner//2,
365
+ out_channels=self.d_inner//2,
366
+ bias=conv_bias//2,
367
+ kernel_size=d_conv,
368
+ groups=self.d_inner//2,
369
+ **factory_kwargs,
370
+ )
371
+ self.conv1d_z = nn.Conv1d(
372
+ in_channels=self.d_inner//2,
373
+ out_channels=self.d_inner//2,
374
+ bias=conv_bias//2,
375
+ kernel_size=d_conv,
376
+ groups=self.d_inner//2,
377
+ **factory_kwargs,
378
+ )
379
+
380
+ def forward(self, hidden_states):
381
+ """
382
+ hidden_states: (B, L, D)
383
+ Returns: same shape as hidden_states
384
+ """
385
+ _, seqlen, _ = hidden_states.shape
386
+ xz = self.in_proj(hidden_states)
387
+ xz = rearrange(xz, "b l d -> b d l")
388
+ x, z = xz.chunk(2, dim=1)
389
+ A = -torch.exp(self.A_log.float())
390
+ x = F.silu(F.conv1d(input=x, weight=self.conv1d_x.weight, bias=self.conv1d_x.bias, padding='same', groups=self.d_inner//2))
391
+ z = F.silu(F.conv1d(input=z, weight=self.conv1d_z.weight, bias=self.conv1d_z.bias, padding='same', groups=self.d_inner//2))
392
+ x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d"))
393
+ dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1)
394
+ dt = rearrange(self.dt_proj(dt), "(b l) d -> b d l", l=seqlen)
395
+ B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
396
+ C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
397
+ y = selective_scan_fn(x,
398
+ dt,
399
+ A,
400
+ B,
401
+ C,
402
+ self.D.float(),
403
+ z=None,
404
+ delta_bias=self.dt_proj.bias.float(),
405
+ delta_softplus=True,
406
+ return_last_state=None)
407
+
408
+ y = torch.cat([y, z], dim=1)
409
+ y = rearrange(y, "b d l -> b l d")
410
+ out = self.out_proj(y)
411
+ return out
412
+
413
+
414
+ class Attention(nn.Module):
415
+
416
+ def __init__(
417
+ self,
418
+ dim,
419
+ num_heads=8,
420
+ qkv_bias=False,
421
+ qk_norm=False,
422
+ attn_drop=0.,
423
+ proj_drop=0.,
424
+ norm_layer=nn.LayerNorm,
425
+ ):
426
+ super().__init__()
427
+ assert dim % num_heads == 0
428
+ self.num_heads = num_heads
429
+ self.head_dim = dim // num_heads
430
+ self.scale = self.head_dim ** -0.5
431
+ self.fused_attn = True
432
+
433
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
434
+ self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
435
+ self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
436
+ self.attn_drop = nn.Dropout(attn_drop)
437
+ self.proj = nn.Linear(dim, dim)
438
+ self.proj_drop = nn.Dropout(proj_drop)
439
+
440
+ def forward(self, x):
441
+ B, N, C = x.shape
442
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
443
+ q, k, v = qkv.unbind(0)
444
+ q, k = self.q_norm(q), self.k_norm(k)
445
+
446
+ if self.fused_attn:
447
+ x = F.scaled_dot_product_attention(
448
+ q, k, v,
449
+ dropout_p=self.attn_drop.p,
450
+ )
451
+ else:
452
+ q = q * self.scale
453
+ attn = q @ k.transpose(-2, -1)
454
+ attn = attn.softmax(dim=-1)
455
+ attn = self.attn_drop(attn)
456
+ x = attn @ v
457
+
458
+ x = x.transpose(1, 2).reshape(B, N, C)
459
+ x = self.proj(x)
460
+ x = self.proj_drop(x)
461
+ return x
462
+
463
+
464
+ class Block(nn.Module):
465
+ def __init__(self,
466
+ dim,
467
+ num_heads,
468
+ counter,
469
+ transformer_blocks,
470
+ mlp_ratio=4.,
471
+ qkv_bias=False,
472
+ qk_scale=False,
473
+ drop=0.,
474
+ attn_drop=0.,
475
+ drop_path=0.,
476
+ act_layer=nn.GELU,
477
+ norm_layer=nn.LayerNorm,
478
+ Mlp_block=Mlp,
479
+ layer_scale=None,
480
+ ):
481
+ super().__init__()
482
+ self.norm1 = norm_layer(dim)
483
+ if counter in transformer_blocks:
484
+ self.mixer = Attention(
485
+ dim,
486
+ num_heads=num_heads,
487
+ qkv_bias=qkv_bias,
488
+ qk_norm=qk_scale,
489
+ attn_drop=attn_drop,
490
+ proj_drop=drop,
491
+ norm_layer=norm_layer,
492
+ )
493
+ else:
494
+ self.mixer = MambaVisionMixer(d_model=dim,
495
+ d_state=8,
496
+ d_conv=3,
497
+ expand=1
498
+ )
499
+
500
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
501
+ self.norm2 = norm_layer(dim)
502
+ mlp_hidden_dim = int(dim * mlp_ratio)
503
+ self.mlp = Mlp_block(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
504
+ use_layer_scale = layer_scale is not None and type(layer_scale) in [int, float]
505
+ self.gamma_1 = nn.Parameter(layer_scale * torch.ones(dim)) if use_layer_scale else 1
506
+ self.gamma_2 = nn.Parameter(layer_scale * torch.ones(dim)) if use_layer_scale else 1
507
+
508
+ def forward(self, x):
509
+ x = x + self.drop_path(self.gamma_1 * self.mixer(self.norm1(x)))
510
+ x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
511
+ return x
512
+
513
+
514
+ class MambaVisionLayer(nn.Module):
515
+ """
516
+ MambaVision layer"
517
+ """
518
+
519
+ def __init__(self,
520
+ dim,
521
+ depth,
522
+ num_heads,
523
+ window_size,
524
+ conv=False,
525
+ downsample=True,
526
+ mlp_ratio=4.,
527
+ qkv_bias=True,
528
+ qk_scale=None,
529
+ drop=0.,
530
+ attn_drop=0.,
531
+ drop_path=0.,
532
+ layer_scale=None,
533
+ layer_scale_conv=None,
534
+ transformer_blocks = [],
535
+ ):
536
+ """
537
+ Args:
538
+ dim: feature size dimension.
539
+ depth: number of layers in each stage.
540
+ window_size: window size in each stage.
541
+ conv: bool argument for conv stage flag.
542
+ downsample: bool argument for down-sampling.
543
+ mlp_ratio: MLP ratio.
544
+ num_heads: number of heads in each stage.
545
+ qkv_bias: bool argument for query, key, value learnable bias.
546
+ qk_scale: bool argument to scaling query, key.
547
+ drop: dropout rate.
548
+ attn_drop: attention dropout rate.
549
+ drop_path: drop path rate.
550
+ norm_layer: normalization layer.
551
+ layer_scale: layer scaling coefficient.
552
+ layer_scale_conv: conv layer scaling coefficient.
553
+ transformer_blocks: list of transformer blocks.
554
+ """
555
+
556
+ super().__init__()
557
+ self.conv = conv
558
+ self.transformer_block = False
559
+ if conv:
560
+ self.blocks = nn.ModuleList([ConvBlock(dim=dim,
561
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
562
+ layer_scale=layer_scale_conv)
563
+ for i in range(depth)])
564
+ self.transformer_block = False
565
+ else:
566
+ self.transformer_block = True
567
+ self.blocks = nn.ModuleList([Block(dim=dim,
568
+ counter=i,
569
+ transformer_blocks=transformer_blocks,
570
+ num_heads=num_heads,
571
+ mlp_ratio=mlp_ratio,
572
+ qkv_bias=qkv_bias,
573
+ qk_scale=qk_scale,
574
+ drop=drop,
575
+ attn_drop=attn_drop,
576
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
577
+ layer_scale=layer_scale)
578
+ for i in range(depth)])
579
+ self.transformer_block = True
580
+
581
+ self.downsample = None if not downsample else Downsample(dim=dim)
582
+ self.do_gt = False
583
+ self.window_size = window_size
584
+
585
+ def forward(self, x):
586
+ _, _, H, W = x.shape
587
+
588
+ if self.transformer_block:
589
+ pad_r = (self.window_size - W % self.window_size) % self.window_size
590
+ pad_b = (self.window_size - H % self.window_size) % self.window_size
591
+ if pad_r > 0 or pad_b > 0:
592
+ x = torch.nn.functional.pad(x, (0,pad_r,0,pad_b))
593
+ _, _, Hp, Wp = x.shape
594
+ else:
595
+ Hp, Wp = H, W
596
+ x = window_partition(x, self.window_size)
597
+
598
+ for _, blk in enumerate(self.blocks):
599
+ x = blk(x)
600
+ if self.transformer_block:
601
+ x = window_reverse(x, self.window_size, Hp, Wp)
602
+ if pad_r > 0 or pad_b > 0:
603
+ x = x[:, :, :H, :W].contiguous()
604
+ if self.downsample is None:
605
+ return x, x
606
+ return self.downsample(x), x
607
+
608
+
609
+ class MambaVision(nn.Module):
610
+ """
611
+ MambaVision,
612
+ """
613
+
614
+ def __init__(self,
615
+ dim,
616
+ in_dim,
617
+ depths,
618
+ window_size,
619
+ mlp_ratio,
620
+ num_heads,
621
+ drop_path_rate=0.2,
622
+ in_chans=3,
623
+ num_classes=1000,
624
+ qkv_bias=True,
625
+ qk_scale=None,
626
+ drop_rate=0.,
627
+ attn_drop_rate=0.,
628
+ layer_scale=None,
629
+ layer_scale_conv=None,
630
+ **kwargs):
631
+ """
632
+ Args:
633
+ dim: feature size dimension.
634
+ depths: number of layers in each stage.
635
+ window_size: window size in each stage.
636
+ mlp_ratio: MLP ratio.
637
+ num_heads: number of heads in each stage.
638
+ drop_path_rate: drop path rate.
639
+ in_chans: number of input channels.
640
+ num_classes: number of classes.
641
+ qkv_bias: bool argument for query, key, value learnable bias.
642
+ qk_scale: bool argument to scaling query, key.
643
+ drop_rate: dropout rate.
644
+ attn_drop_rate: attention dropout rate.
645
+ norm_layer: normalization layer.
646
+ layer_scale: layer scaling coefficient.
647
+ layer_scale_conv: conv layer scaling coefficient.
648
+ """
649
+ super().__init__()
650
+ num_features = int(dim * 2 ** (len(depths) - 1))
651
+ self.num_classes = num_classes
652
+ self.patch_embed = PatchEmbed(in_chans=in_chans, in_dim=in_dim, dim=dim)
653
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
654
+ self.levels = nn.ModuleList()
655
+ for i in range(len(depths)):
656
+ conv = True if (i == 0 or i == 1) else False
657
+ level = MambaVisionLayer(dim=int(dim * 2 ** i),
658
+ depth=depths[i],
659
+ num_heads=num_heads[i],
660
+ window_size=window_size[i],
661
+ mlp_ratio=mlp_ratio,
662
+ qkv_bias=qkv_bias,
663
+ qk_scale=qk_scale,
664
+ conv=conv,
665
+ drop=drop_rate,
666
+ attn_drop=attn_drop_rate,
667
+ drop_path=dpr[sum(depths[:i]):sum(depths[:i + 1])],
668
+ downsample=(i < 3),
669
+ layer_scale=layer_scale,
670
+ layer_scale_conv=layer_scale_conv,
671
+ transformer_blocks=list(range(depths[i]//2+1, depths[i])) if depths[i]%2!=0 else list(range(depths[i]//2, depths[i])),
672
+ )
673
+ self.levels.append(level)
674
+ self.norm = nn.BatchNorm2d(num_features)
675
+ self.avgpool = nn.AdaptiveAvgPool2d(1)
676
+ self.head = nn.Linear(num_features, num_classes) if num_classes > 0 else nn.Identity()
677
+ self.apply(self._init_weights)
678
+
679
+ def _init_weights(self, m):
680
+ if isinstance(m, nn.Linear):
681
+ trunc_normal_(m.weight, std=.02)
682
+ if isinstance(m, nn.Linear) and m.bias is not None:
683
+ nn.init.constant_(m.bias, 0)
684
+ elif isinstance(m, nn.LayerNorm):
685
+ nn.init.constant_(m.bias, 0)
686
+ nn.init.constant_(m.weight, 1.0)
687
+ elif isinstance(m, LayerNorm2d):
688
+ nn.init.constant_(m.bias, 0)
689
+ nn.init.constant_(m.weight, 1.0)
690
+ elif isinstance(m, nn.BatchNorm2d):
691
+ nn.init.ones_(m.weight)
692
+ nn.init.zeros_(m.bias)
693
+
694
+ @torch.jit.ignore
695
+ def no_weight_decay_keywords(self):
696
+ return {'rpb'}
697
+
698
+ def forward_features(self, x):
699
+ x = self.patch_embed(x)
700
+ outs = []
701
+ for level in self.levels:
702
+ x, xo = level(x)
703
+ outs.append(xo)
704
+ x = self.norm(x)
705
+ x = self.avgpool(x)
706
+ x = torch.flatten(x, 1)
707
+ return x, outs
708
+
709
+ def forward(self, x):
710
+ x, outs = self.forward_features(x)
711
+ x = self.head(x)
712
+ return x
713
+
714
+ def _load_state_dict(self,
715
+ pretrained,
716
+ strict: bool = False):
717
+ _load_checkpoint(self,
718
+ pretrained,
719
+ strict=strict)
720
+
721
+
722
+ class MambaVisionModel(PreTrainedModel):
723
+ config_class = MambaVisionConfig
724
+
725
+ def __init__(self, config):
726
+ super().__init__(config)
727
+ self.model = MambaVision(
728
+ depths=config.depths,
729
+ num_heads=config.num_heads,
730
+ window_size=config.window_size,
731
+ dim=config.dim,
732
+ in_dim=config.in_dim,
733
+ mlp_ratio=config.mlp_ratio,
734
+ layer_scale=config.layer_scale,
735
+ layer_scale_conv=config.layer_scale_conv
736
+ )
737
+
738
+ def forward(self, tensor):
739
+ return self.model.forward_features(tensor)
740
+
741
+
742
+ class MambaVisionModelForImageClassification(PreTrainedModel):
743
+ config_class = MambaVisionConfig
744
+
745
+
746
+ def __init__(self, config):
747
+ super().__init__(config)
748
+ self.model = MambaVision(
749
+ depths=config.depths,
750
+ num_heads=config.num_heads,
751
+ window_size=config.window_size,
752
+ dim=config.dim,
753
+ in_dim=config.in_dim,
754
+ mlp_ratio=config.mlp_ratio,
755
+ layer_scale=config.layer_scale,
756
+ layer_scale_conv=config.layer_scale_conv
757
+ )
758
+
759
+ def forward(self, tensor, labels=None):
760
+ logits = self.model(tensor)
761
+ if labels is not None:
762
+ loss = torch.nn.cross_entropy(logits, labels)
763
+ return {"loss": loss, "logits": logits}
764
+ return {"logits": logits}