Manli committed on
Commit cc057ef
Parent: 8bca717

Merge modeling files into a single one to avoid relative import

Files changed (6)
  1. config.json +1 -1
  2. configuration_xgenmm.py +0 -157
  3. image_processing_blip_3.py +150 -91
  4. modeling_xgenmm.py +1817 -38
  5. utils.py +0 -383
  6. vlm.py +0 -1314
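
After the merge, configuration_xgenmm.py, utils.py, and vlm.py are removed and their contents live in modeling_xgenmm.py, so the remote-code module needs no intra-repo (relative) imports. A minimal sketch of what this enables, assuming the merged file is importable locally (class names are taken from the diff below):

    # One absolute import now covers both the config and the model class.
    from modeling_xgenmm import XGenMMConfig, XGenMMModelForConditionalGeneration

    config = XGenMMConfig()  # builds the default SigLIP vision + Phi-3 text sub-configs
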
config.json CHANGED
@@ -3,7 +3,7 @@
     "XGenMMModelForConditionalGeneration"
   ],
   "auto_map": {
-    "AutoConfig": "configuration_xgenmm.XGenMMConfig",
+    "AutoConfig": "modeling_xgenmm.XGenMMConfig",
     "AutoModelForVision2Seq": "modeling_xgenmm.XGenMMModelForConditionalGeneration"
   },
   "model_type": "xgenmm",
configuration_xgenmm.py DELETED
@@ -1,157 +0,0 @@
-from transformers import PretrainedConfig
-from transformers import logging
-from transformers import CONFIG_MAPPING
-
-logger = logging.get_logger(__name__)
-
-class XGenMMVisionEncoderConfig(PretrainedConfig):
-    model_type = "xgenmm_vision_encoder"
-
-    def __init__(self,
-                 model_name: str = 'google/siglip-so400m-patch14-384',
-                 **kwargs):
-        self.model_name = model_name
-        super().__init__(**kwargs)
-
-
-class XGenMMVisionTokenizerConfig(PretrainedConfig):
-    model_type = "xgenmm_vision_tokenizer"
-
-    def __init__(self,
-                 vis_feature_dim: int = 1152,
-                 lang_embedding_dim: int = 3072,
-                 num_vis_tokens: int = 128,
-                 image_aspect_ratio: str = 'none',
-                 **kwargs):
-        self.vis_feature_dim = vis_feature_dim
-        self.lang_embedding_dim = lang_embedding_dim
-        self.num_vis_tokens = num_vis_tokens
-        self.image_aspect_ratio = image_aspect_ratio
-        super().__init__(**kwargs)
-
-
-class XGenMMConfig(PretrainedConfig):
-    model_type = "xgenmm"
-
-    def __init__(self,
-                 vision_encoder_config: dict = None,
-                 vision_tokenizer_config: dict = None,
-                 text_config: dict = None,
-                 **kwargs):
-
-        if vision_encoder_config is None:
-            vision_encoder_config = {'image_aspect_ratio': 'anyres', 'anyres_patch_sampling': True}
-            logger.info("vision_encoder_config is None. initializing the XGenMMVisionEncoderConfig with default values.")
-
-        if vision_tokenizer_config is None:
-            vision_tokenizer_config = {}
-            logger.info("vision_tokenizer_config is None. Initializing the XGenMMVisionTokenizerConfig with default values.")
-
-        if text_config is None:
-            text_config = {
-                'initial_tokenizer_len':32012,
-                'pad_token_id':32011,
-                'bos_token_id':1,
-                'eos_token_id':32000,
-                'vocab_size': 32064,
-                'hidden_size': 3072,
-                'intermediate_size': 8192,
-                'num_hidden_layers': 32,
-                'num_attention_heads': 32,
-                'num_key_value_heads': 32,
-                'resid_pdrop': 0.0,
-                'embd_pdrop': 0.0,
-                'attention_dropout': 0.0,
-                'hidden_act': 'silu',
-                'max_position_embeddings': 4096,
-                'original_max_position_embeddings': 4096,
-                'initializer_range': 0.02,
-                'rms_norm_eps': 1e-05,
-                'use_cache': True,
-                'rope_theta': 10000.0,
-                'rope_scaling': None,
-                'sliding_window': 2047,
-                'return_dict': True,
-                'output_hidden_states': False,
-                'output_attentions': False,
-                'torchscript': False,
-                'torch_dtype': 'bfloat16',
-                'use_bfloat16': False,
-                'tf_legacy_loss': False,
-                'pruned_heads': {},
-                'tie_word_embeddings': False,
-                'chunk_size_feed_forward': 0,
-                'is_encoder_decoder': False,
-                'is_decoder': False,
-                'cross_attention_hidden_size': None,
-                'add_cross_attention': False,
-                'tie_encoder_decoder': False,
-                'max_length': 20,
-                'min_length': 0,
-                'do_sample': False,
-                'early_stopping': False,
-                'num_beams': 1,
-                'num_beam_groups': 1,
-                'diversity_penalty': 0.0,
-                'temperature': 1.0,
-                'top_k': 50,
-                'top_p': 1.0,
-                'typical_p': 1.0,
-                'repetition_penalty': 1.0,
-                'length_penalty': 1.0,
-                'no_repeat_ngram_size': 0,
-                'encoder_no_repeat_ngram_size': 0,
-                'bad_words_ids': None,
-                'num_return_sequences': 1,
-                'output_scores': False,
-                'return_dict_in_generate': False,
-                'forced_bos_token_id': None,
-                'forced_eos_token_id': None,
-                'remove_invalid_values': False,
-                'exponential_decay_length_penalty': None,
-                'suppress_tokens': None,
-                'begin_suppress_tokens': None,
-                'finetuning_task': None,
-                'id2label': {0: 'LABEL_0', 1: 'LABEL_1'},
-                'label2id': {'LABEL_0': 0, 'LABEL_1': 1},
-                'tokenizer_class': None,
-                'prefix': None,
-                'bos_token_id': 1,
-                'pad_token_id': 32000,
-                'eos_token_id': 32000,
-                'sep_token_id': None,
-                'decoder_start_token_id': None,
-                'task_specific_params': None,
-                'problem_type': None,
-                'model_type': 'phi3'
-            }
-            logger.info("text_config is None. Initializing the text config with default values (`Phi3Config`).")
-
-        self.vision_encoder_config = XGenMMVisionEncoderConfig(**vision_encoder_config)
-
-        self.vision_tokenizer_config = XGenMMVisionTokenizerConfig(**vision_tokenizer_config)
-
-        text_model_type = text_config["model_type"] if "model_type" in text_config else "phi3"
-        self.text_config = CONFIG_MAPPING[text_model_type](**text_config)
-
-        for key in ['initial_tokenizer_len', 'pad_token_id']:
-            if key not in self.text_config.to_dict():
-                raise ValueError(f"The key `{key}` is missing in the text_config.")
-
-        super().__init__(**kwargs)
-
-    @classmethod
-    def from_vision_encoder_vision_tokenizer_text_configs(
-        cls,
-        vision_encoder_config: XGenMMVisionEncoderConfig,
-        vision_tokenizer_config: XGenMMVisionTokenizerConfig,
-        text_config: PretrainedConfig,
-        **kwargs):
-
-        return cls(
-            vision_encoder_config=vision_encoder_config.to_dict(),
-            vision_tokenizer_config=vision_tokenizer_config.to_dict(),
-            text_config=text_config.to_dict(),
-            **kwargs,
-        )
-
image_processing_blip_3.py CHANGED
@@ -1,10 +1,19 @@
1
  import random
2
  from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
3
  import torchvision.transforms.functional as F
4
- from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, \
5
- CenterCrop, ColorJitter, Grayscale
6
  import numbers
7
- import torch
8
  import ast
9
  import math
10
  import numpy as np
@@ -13,11 +22,23 @@ from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
13
  from transformers.image_utils import ImageInput
14
  from transformers.utils import TensorType
15
 
16
- from utils import expand2square
17
 
18
-
19
  class Blip3ImageProcessor(BaseImageProcessor):
20
-
21
  def __init__(
22
  self,
23
  do_resize: bool = True,
@@ -37,104 +58,116 @@ class Blip3ImageProcessor(BaseImageProcessor):
37
 
38
  self.image_mean = image_mean if image_mean is not None else [0.5, 0.5, 0.5]
39
  self.image_std = image_std if image_std is not None else [0.5, 0.5, 0.5]
40
-
41
-
42
  @classmethod
43
- def resize(cls, image_size, resize_mode, interpolation='bicubic', fill_color=0):
44
- interpolation_mode = InterpolationMode.BILINEAR if interpolation == 'bilinear' else InterpolationMode.BICUBIC
45
- if resize_mode == 'longest':
 
 
 
 
46
  transforms = [
47
- ResizeKeepRatio(image_size, interpolation=interpolation_mode, longest=1),
48
- CenterCropOrPad(image_size, fill=fill_color)
 
 
49
  ]
50
- elif resize_mode == 'squash':
51
  if isinstance(image_size, int):
52
  image_size = (image_size, image_size)
53
  transforms = [
54
  Resize(image_size, interpolation=interpolation_mode),
55
  ]
56
  else:
57
- assert resize_mode == 'shortest'
58
  if not isinstance(image_size, (tuple, list)):
59
  image_size = (image_size, image_size)
60
  if image_size[0] == image_size[1]:
61
  # simple case, use torchvision built-in Resize w/ shortest edge mode (scalar size arg)
62
- transforms = [
63
- Resize(image_size[0], interpolation=interpolation_mode)
64
- ]
65
  else:
66
  # resize shortest edge to matching target dim for non-square target
67
  transforms = [ResizeKeepRatio(image_size)]
68
  transforms += [CenterCrop(image_size)]
69
  return transforms
70
-
71
  @classmethod
72
  def convert_rgb(cls, image):
73
  return image.convert("RGB")
74
-
75
-
76
- def _preprocess(self,
77
- images: ImageInput
78
- ) -> torch.Tensor:
79
- transforms = self.resize(self.size, self.resize_mode, self.interpolation_mode)
80
- transforms.extend([
81
- self.convert_rgb,
82
- ToTensor(),
83
- Normalize(mean=self.image_mean, std=self.image_std)
84
- ])
85
  composed_transforms = Compose(transforms)
86
  images_tensor = composed_transforms(images)
87
- return images_tensor
88
-
89
- def preprocess(self,
90
- images: ImageInput,
91
- return_tensors: Optional[Union[str, TensorType]] = None,
92
- **kwargs) -> BatchFeature:
93
- if 'image_aspect_ratio' in kwargs:
94
- image_aspect_ratio = kwargs['image_aspect_ratio']
 
 
95
  else:
96
- image_aspect_ratio = 'none'
97
  new_images = []
98
- if image_aspect_ratio == 'pad':
99
  for image in images:
100
- image = expand2square(image, tuple(int(x*255) for x in self.image_mean))
 
 
101
  image = self._preprocess(image)
102
  new_images.append(image)
103
- elif image_aspect_ratio == 'anyres':
104
  for image in images:
105
- image = process_anyres_image(image, self._preprocess, self.size,
106
- self.grids)
 
107
  new_images.append(image)
108
  else:
109
  for image in images:
110
  image = self._preprocess(image)
111
  new_images.append(image)
112
-
113
  if all(x.shape == new_images[0].shape for x in new_images):
114
  new_images = torch.stack(new_images, dim=0)
115
- if image_aspect_ratio == 'anyres':
116
- new_images = BatchFeature(data={"pixel_values": new_images}, tensor_type=return_tensors)
 
 
117
  else:
118
- new_images = BatchFeature(data={"pixel_values": new_images.unsqueeze(1).unsqueeze(0)}, tensor_type=return_tensors)
119
-
 
 
 
120
  return new_images
121
 
122
-
123
  class ResizeKeepRatio:
124
- """ Resize and Keep Ratio
125
 
126
  Copy & paste from `timm`
127
  """
128
 
129
  def __init__(
130
- self,
131
- size,
132
- longest=0.,
133
- interpolation=InterpolationMode.BICUBIC,
134
- random_scale_prob=0.,
135
- random_scale_range=(0.85, 1.05),
136
- random_aspect_prob=0.,
137
- random_aspect_range=(0.9, 1.11)
138
  ):
139
  if isinstance(size, (list, tuple)):
140
  self.size = tuple(size)
@@ -149,30 +182,36 @@ class ResizeKeepRatio:
149
 
150
  @staticmethod
151
  def get_params(
152
- img,
153
- target_size,
154
- longest,
155
- random_scale_prob=0.,
156
- random_scale_range=(0.85, 1.05),
157
- random_aspect_prob=0.,
158
- random_aspect_range=(0.9, 1.11)
159
  ):
160
- """Get parameters
161
- """
162
  source_size = img.size[::-1] # h, w
163
  h, w = source_size
164
  target_h, target_w = target_size
165
  ratio_h = h / target_h
166
  ratio_w = w / target_w
167
- ratio = max(ratio_h, ratio_w) * longest + min(ratio_h, ratio_w) * (1. - longest)
 
 
168
  if random_scale_prob > 0 and random.random() < random_scale_prob:
169
  ratio_factor = random.uniform(random_scale_range[0], random_scale_range[1])
170
  ratio_factor = (ratio_factor, ratio_factor)
171
  else:
172
- ratio_factor = (1., 1.)
173
  if random_aspect_prob > 0 and random.random() < random_aspect_prob:
174
- aspect_factor = random.uniform(random_aspect_range[0], random_aspect_range[1])
175
- ratio_factor = (ratio_factor[0] / aspect_factor, ratio_factor[1] * aspect_factor)
 
 
 
 
 
176
  size = [round(x * f / ratio) for x, f in zip(source_size, ratio_factor)]
177
  return size
178
 
@@ -185,19 +224,24 @@ class ResizeKeepRatio:
185
  PIL Image: Resized, padded to at least target size, possibly cropped to exactly target size
186
  """
187
  size = self.get_params(
188
- img, self.size, self.longest,
189
- self.random_scale_prob, self.random_scale_range,
190
- self.random_aspect_prob, self.random_aspect_range
 
 
 
 
191
  )
192
  img = F.resize(img, size, self.interpolation)
193
  return img
194
 
195
  def __repr__(self):
196
- format_string = self.__class__.__name__ + '(size={0}'.format(self.size)
197
- format_string += f', interpolation={self.interpolation})'
198
- format_string += f', longest={self.longest:.3f})'
199
  return format_string
200
 
 
201
  def _setup_size(size, error_msg):
202
  if isinstance(size, numbers.Number):
203
  return int(size), int(size)
@@ -210,7 +254,10 @@ def _setup_size(size, error_msg):
210
 
211
  return size
212
 
213
- def center_crop_or_pad(img: torch.Tensor, output_size: List[int], fill=0) -> torch.Tensor:
 
 
 
214
  """Center crops and/or pads the given image.
215
  If the image is torch Tensor, it is expected
216
  to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
@@ -248,7 +295,8 @@ def center_crop_or_pad(img: torch.Tensor, output_size: List[int], fill=0) -> tor
248
  crop_top = int(round((image_height - crop_height) / 2.0))
249
  crop_left = int(round((image_width - crop_width) / 2.0))
250
  return F.crop(img, crop_top, crop_left, crop_height, crop_width)
251
-
 
252
  class CenterCropOrPad(torch.nn.Module):
253
  """Crops the given image at the center.
254
  If the image is torch Tensor, it is expected
@@ -263,7 +311,9 @@ class CenterCropOrPad(torch.nn.Module):
263
 
264
  def __init__(self, size, fill=0):
265
  super().__init__()
266
- self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
 
 
267
  self.fill = fill
268
 
269
  def forward(self, img):
@@ -278,7 +328,8 @@ class CenterCropOrPad(torch.nn.Module):
278
 
279
  def __repr__(self) -> str:
280
  return f"{self.__class__.__name__}(size={self.size})"
281
-
 
282
  def process_anyres_image(image, processor, processor_size, grid_pinpoints):
283
  """
284
  Process an image with variable resolutions.
@@ -306,9 +357,8 @@ def process_anyres_image(image, processor, processor_size, grid_pinpoints):
306
  image_original_resize = image.resize((processor_size[0], processor_size[0]))
307
 
308
  image_patches = [image_original_resize] + patches
309
- image_patches = [processor(image_patch)
310
- for image_patch in image_patches]
311
- return torch.stack(image_patches, dim=0)
312
 
313
 
314
  def select_best_resolution(original_size, possible_resolutions):
@@ -325,21 +375,29 @@ def select_best_resolution(original_size, possible_resolutions):
325
  original_width, original_height = original_size
326
  best_fit = None
327
  max_effective_resolution = 0
328
- min_wasted_resolution = float('inf')
329
 
330
  for width, height in possible_resolutions:
331
  scale = min(width / original_width, height / original_height)
332
- downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale)
333
- effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height)
 
 
 
 
334
  wasted_resolution = (width * height) - effective_resolution
335
 
336
- if effective_resolution > max_effective_resolution or (effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution):
 
 
 
337
  max_effective_resolution = effective_resolution
338
  min_wasted_resolution = wasted_resolution
339
  best_fit = (width, height)
340
 
341
  return best_fit
342
 
 
343
  def resize_and_pad_image(image, target_resolution):
344
  """
345
  Resize and pad an image to a target resolution while maintaining aspect ratio.
@@ -367,13 +425,14 @@ def resize_and_pad_image(image, target_resolution):
367
  # Resize the image
368
  resized_image = image.resize((new_width, new_height))
369
 
370
- new_image = Image.new('RGB', (target_width, target_height), (0, 0, 0))
371
  paste_x = (target_width - new_width) // 2
372
  paste_y = (target_height - new_height) // 2
373
  new_image.paste(resized_image, (paste_x, paste_y))
374
 
375
  return new_image
376
 
 
377
  def divide_to_patches(image, patch_size):
378
  """
379
  Divides an image into patches of a specified size.
@@ -393,4 +452,4 @@ def divide_to_patches(image, patch_size):
393
  patch = image.crop(box)
394
  patches.append(patch)
395
 
396
- return patches
 
1
  import random
2
  from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
3
  import torchvision.transforms.functional as F
4
+ from torchvision.transforms import (
5
+ Normalize,
6
+ Compose,
7
+ RandomResizedCrop,
8
+ InterpolationMode,
9
+ ToTensor,
10
+ Resize,
11
+ CenterCrop,
12
+ ColorJitter,
13
+ Grayscale,
14
+ )
15
  import numbers
16
+ import torch
17
  import ast
18
  import math
19
  import numpy as np
 
22
  from transformers.image_utils import ImageInput
23
  from transformers.utils import TensorType
24
 
 
25
 
26
+ def expand2square(pil_img, background_color):
27
+ width, height = pil_img.size
28
+ if width == height:
29
+ return pil_img
30
+ elif width > height:
31
+ result = Image.new(pil_img.mode, (width, width), background_color)
32
+ result.paste(pil_img, (0, (width - height) // 2))
33
+ return result
34
+ else:
35
+ result = Image.new(pil_img.mode, (height, height), background_color)
36
+ result.paste(pil_img, ((height - width) // 2, 0))
37
+ return result
38
+
39
+
40
  class Blip3ImageProcessor(BaseImageProcessor):
41
+
42
  def __init__(
43
  self,
44
  do_resize: bool = True,
 
58
 
59
  self.image_mean = image_mean if image_mean is not None else [0.5, 0.5, 0.5]
60
  self.image_std = image_std if image_std is not None else [0.5, 0.5, 0.5]
61
+
 
62
  @classmethod
63
+ def resize(cls, image_size, resize_mode, interpolation="bicubic", fill_color=0):
64
+ interpolation_mode = (
65
+ InterpolationMode.BILINEAR
66
+ if interpolation == "bilinear"
67
+ else InterpolationMode.BICUBIC
68
+ )
69
+ if resize_mode == "longest":
70
  transforms = [
71
+ ResizeKeepRatio(
72
+ image_size, interpolation=interpolation_mode, longest=1
73
+ ),
74
+ CenterCropOrPad(image_size, fill=fill_color),
75
  ]
76
+ elif resize_mode == "squash":
77
  if isinstance(image_size, int):
78
  image_size = (image_size, image_size)
79
  transforms = [
80
  Resize(image_size, interpolation=interpolation_mode),
81
  ]
82
  else:
83
+ assert resize_mode == "shortest"
84
  if not isinstance(image_size, (tuple, list)):
85
  image_size = (image_size, image_size)
86
  if image_size[0] == image_size[1]:
87
  # simple case, use torchvision built-in Resize w/ shortest edge mode (scalar size arg)
88
+ transforms = [Resize(image_size[0], interpolation=interpolation_mode)]
 
 
89
  else:
90
  # resize shortest edge to matching target dim for non-square target
91
  transforms = [ResizeKeepRatio(image_size)]
92
  transforms += [CenterCrop(image_size)]
93
  return transforms
94
+
95
  @classmethod
96
  def convert_rgb(cls, image):
97
  return image.convert("RGB")
98
+
99
+ def _preprocess(self, images: ImageInput) -> torch.Tensor:
100
+ transforms = self.resize(self.size, self.resize_mode, self.interpolation_mode)
101
+ transforms.extend(
102
+ [
103
+ self.convert_rgb,
104
+ ToTensor(),
105
+ Normalize(mean=self.image_mean, std=self.image_std),
106
+ ]
107
+ )
 
108
  composed_transforms = Compose(transforms)
109
  images_tensor = composed_transforms(images)
110
+ return images_tensor
111
+
112
+ def preprocess(
113
+ self,
114
+ images: ImageInput,
115
+ return_tensors: Optional[Union[str, TensorType]] = None,
116
+ **kwargs,
117
+ ) -> BatchFeature:
118
+ if "image_aspect_ratio" in kwargs:
119
+ image_aspect_ratio = kwargs["image_aspect_ratio"]
120
  else:
121
+ image_aspect_ratio = "none"
122
  new_images = []
123
+ if image_aspect_ratio == "pad":
124
  for image in images:
125
+ image = expand2square(
126
+ image, tuple(int(x * 255) for x in self.image_mean)
127
+ )
128
  image = self._preprocess(image)
129
  new_images.append(image)
130
+ elif image_aspect_ratio == "anyres":
131
  for image in images:
132
+ image = process_anyres_image(
133
+ image, self._preprocess, self.size, self.grids
134
+ )
135
  new_images.append(image)
136
  else:
137
  for image in images:
138
  image = self._preprocess(image)
139
  new_images.append(image)
140
+
141
  if all(x.shape == new_images[0].shape for x in new_images):
142
  new_images = torch.stack(new_images, dim=0)
143
+ if image_aspect_ratio == "anyres":
144
+ new_images = BatchFeature(
145
+ data={"pixel_values": new_images}, tensor_type=return_tensors
146
+ )
147
  else:
148
+ new_images = BatchFeature(
149
+ data={"pixel_values": new_images.unsqueeze(1).unsqueeze(0)},
150
+ tensor_type=return_tensors,
151
+ )
152
+
153
  return new_images
154
 
155
+
156
  class ResizeKeepRatio:
157
+ """Resize and Keep Ratio
158
 
159
  Copy & paste from `timm`
160
  """
161
 
162
  def __init__(
163
+ self,
164
+ size,
165
+ longest=0.0,
166
+ interpolation=InterpolationMode.BICUBIC,
167
+ random_scale_prob=0.0,
168
+ random_scale_range=(0.85, 1.05),
169
+ random_aspect_prob=0.0,
170
+ random_aspect_range=(0.9, 1.11),
171
  ):
172
  if isinstance(size, (list, tuple)):
173
  self.size = tuple(size)
 
182
 
183
  @staticmethod
184
  def get_params(
185
+ img,
186
+ target_size,
187
+ longest,
188
+ random_scale_prob=0.0,
189
+ random_scale_range=(0.85, 1.05),
190
+ random_aspect_prob=0.0,
191
+ random_aspect_range=(0.9, 1.11),
192
  ):
193
+ """Get parameters"""
 
194
  source_size = img.size[::-1] # h, w
195
  h, w = source_size
196
  target_h, target_w = target_size
197
  ratio_h = h / target_h
198
  ratio_w = w / target_w
199
+ ratio = max(ratio_h, ratio_w) * longest + min(ratio_h, ratio_w) * (
200
+ 1.0 - longest
201
+ )
202
  if random_scale_prob > 0 and random.random() < random_scale_prob:
203
  ratio_factor = random.uniform(random_scale_range[0], random_scale_range[1])
204
  ratio_factor = (ratio_factor, ratio_factor)
205
  else:
206
+ ratio_factor = (1.0, 1.0)
207
  if random_aspect_prob > 0 and random.random() < random_aspect_prob:
208
+ aspect_factor = random.uniform(
209
+ random_aspect_range[0], random_aspect_range[1]
210
+ )
211
+ ratio_factor = (
212
+ ratio_factor[0] / aspect_factor,
213
+ ratio_factor[1] * aspect_factor,
214
+ )
215
  size = [round(x * f / ratio) for x, f in zip(source_size, ratio_factor)]
216
  return size
217
 
 
224
  PIL Image: Resized, padded to at least target size, possibly cropped to exactly target size
225
  """
226
  size = self.get_params(
227
+ img,
228
+ self.size,
229
+ self.longest,
230
+ self.random_scale_prob,
231
+ self.random_scale_range,
232
+ self.random_aspect_prob,
233
+ self.random_aspect_range,
234
  )
235
  img = F.resize(img, size, self.interpolation)
236
  return img
237
 
238
  def __repr__(self):
239
+ format_string = self.__class__.__name__ + "(size={0}".format(self.size)
240
+ format_string += f", interpolation={self.interpolation})"
241
+ format_string += f", longest={self.longest:.3f})"
242
  return format_string
243
 
244
+
245
  def _setup_size(size, error_msg):
246
  if isinstance(size, numbers.Number):
247
  return int(size), int(size)
 
254
 
255
  return size
256
 
257
+
258
+ def center_crop_or_pad(
259
+ img: torch.Tensor, output_size: List[int], fill=0
260
+ ) -> torch.Tensor:
261
  """Center crops and/or pads the given image.
262
  If the image is torch Tensor, it is expected
263
  to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
 
295
  crop_top = int(round((image_height - crop_height) / 2.0))
296
  crop_left = int(round((image_width - crop_width) / 2.0))
297
  return F.crop(img, crop_top, crop_left, crop_height, crop_width)
298
+
299
+
300
  class CenterCropOrPad(torch.nn.Module):
301
  """Crops the given image at the center.
302
  If the image is torch Tensor, it is expected
 
311
 
312
  def __init__(self, size, fill=0):
313
  super().__init__()
314
+ self.size = _setup_size(
315
+ size, error_msg="Please provide only two dimensions (h, w) for size."
316
+ )
317
  self.fill = fill
318
 
319
  def forward(self, img):
 
328
 
329
  def __repr__(self) -> str:
330
  return f"{self.__class__.__name__}(size={self.size})"
331
+
332
+
333
  def process_anyres_image(image, processor, processor_size, grid_pinpoints):
334
  """
335
  Process an image with variable resolutions.
 
357
  image_original_resize = image.resize((processor_size[0], processor_size[0]))
358
 
359
  image_patches = [image_original_resize] + patches
360
+ image_patches = [processor(image_patch) for image_patch in image_patches]
361
+ return torch.stack(image_patches, dim=0)
 
362
 
363
 
364
  def select_best_resolution(original_size, possible_resolutions):
 
375
  original_width, original_height = original_size
376
  best_fit = None
377
  max_effective_resolution = 0
378
+ min_wasted_resolution = float("inf")
379
 
380
  for width, height in possible_resolutions:
381
  scale = min(width / original_width, height / original_height)
382
+ downscaled_width, downscaled_height = int(original_width * scale), int(
383
+ original_height * scale
384
+ )
385
+ effective_resolution = min(
386
+ downscaled_width * downscaled_height, original_width * original_height
387
+ )
388
  wasted_resolution = (width * height) - effective_resolution
389
 
390
+ if effective_resolution > max_effective_resolution or (
391
+ effective_resolution == max_effective_resolution
392
+ and wasted_resolution < min_wasted_resolution
393
+ ):
394
  max_effective_resolution = effective_resolution
395
  min_wasted_resolution = wasted_resolution
396
  best_fit = (width, height)
397
 
398
  return best_fit
399
 
400
+
401
  def resize_and_pad_image(image, target_resolution):
402
  """
403
  Resize and pad an image to a target resolution while maintaining aspect ratio.
 
425
  # Resize the image
426
  resized_image = image.resize((new_width, new_height))
427
 
428
+ new_image = Image.new("RGB", (target_width, target_height), (0, 0, 0))
429
  paste_x = (target_width - new_width) // 2
430
  paste_y = (target_height - new_height) // 2
431
  new_image.paste(resized_image, (paste_x, paste_y))
432
 
433
  return new_image
434
 
435
+
436
  def divide_to_patches(image, patch_size):
437
  """
438
  Divides an image into patches of a specified size.
 
452
  patch = image.crop(box)
453
  patches.append(patch)
454
 
455
+ return patches
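
The helpers added above (expand2square, select_best_resolution, and friends) are plain functions, so they can be sanity-checked in isolation. A small illustrative sketch, assuming image_processing_blip_3.py and its dependencies are importable; the sizes are made up for the example:

    from PIL import Image
    from image_processing_blip_3 import expand2square, select_best_resolution

    # expand2square pads the shorter side symmetrically: 640x480 -> 640x640.
    img = Image.new("RGB", (640, 480), (255, 0, 0))
    assert expand2square(img, (127, 127, 127)).size == (640, 640)

    # select_best_resolution keeps the pinpoint with the largest effective area
    # (ties broken by least wasted area): 800x600 fits best into 768x768 here.
    assert select_best_resolution((800, 600), [(384, 384), (768, 384), (768, 768)]) == (768, 768)
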
modeling_xgenmm.py CHANGED
@@ -1,29 +1,1799 @@
1
- from transformers import PreTrainedModel, AutoModelForCausalLM, AutoModel
2
  import torch
3
- import open_clip
 
 
4
  from typing import List, Optional, Tuple, Union
5
- from utils import check_embedding_fns
6
- from vlm import PerceiverResampler, XGenMMPerceiver
7
- from configuration_xgenmm import XGenMMVisionEncoderConfig, XGenMMVisionTokenizerConfig, XGenMMConfig
8
 
9
  class XGenMMVisionEncoder(PreTrainedModel):
10
  main_input_name = "pixel_values"
11
  config_class = XGenMMVisionEncoderConfig
12
-
13
  def __init__(self, config: XGenMMVisionEncoderConfig):
14
  super().__init__(config)
15
- if config.model_name != 'google/siglip-so400m-patch14-384':
16
- raise ValueError(f"Unsupported model {config.model_name}. New vision models will be added soon.")
 
 
17
  self.model = AutoModel.from_pretrained(config.model_name)
18
-
19
  def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
20
  # assert pixel_values.ndim == 4, f"Expected 4D tensor (bs, c, h, w), got {pixel_values.ndim}"
21
  return self.model.encode_image(pixel_values)
22
-
23
 
24
- # vision tokenizer
 
25
  class XGenMMVisionTokenizer(PreTrainedModel):
26
  config_class = XGenMMVisionTokenizerConfig
 
27
  def __init__(self, config: XGenMMVisionTokenizerConfig):
28
  super().__init__(config)
29
  self.model = PerceiverResampler(
@@ -31,48 +1801,56 @@ class XGenMMVisionTokenizer(PreTrainedModel):
31
  dim_inner=config.lang_embedding_dim,
32
  num_latents=config.num_vis_tokens,
33
  )
34
-
35
- def forward(self,
36
- vision_features: torch.Tensor,
37
- vision_attn_masks: torch.Tensor):
38
  return self.model(vision_features, vision_attn_masks)
39
-
 
40
  # XGenMM model
41
  class XGenMMModelForConditionalGeneration(PreTrainedModel):
42
  config_class = XGenMMConfig
43
-
44
  def __init__(self, config: XGenMMConfig):
45
  super().__init__(config)
46
-
47
  # vision encoder initialization
48
- vision_encoder = AutoModel.from_pretrained(config.vision_encoder_config.model_name).vision_model
49
-
50
- # language model initialization
 
 
51
  language_model = AutoModelForCausalLM.from_config(config.text_config)
52
  check_embedding_fns(language_model)
53
  # Update _tied_weights_keys using the base model used.
54
  if language_model._tied_weights_keys is not None:
55
- self._tied_weights_keys = [f"language_model.{k}" for k in language_model._tied_weights_keys]
56
-
 
 
57
  # vision tokenizer initialization
58
- if config.vision_tokenizer_config.lang_embedding_dim != language_model.get_input_embeddings().weight.shape[1]:
 
 
 
59
  overwrite = language_model.get_input_embeddings().weight.shape[1]
60
  config.vision_tokenizer_config.lang_embedding_dim = overwrite
61
- print(f"Warning: The language embedding dimension in the vision tokenizer config is different from the language model's embedding dimension. Overwriting the language embedding dimension in the vision tokenizer config to {overwrite}.")
62
-
 
 
63
  vision_tokenizer = XGenMMVisionTokenizer(config.vision_tokenizer_config).model
64
 
65
  self.vlm = XGenMMPerceiver(
66
  vision_encoder=vision_encoder,
67
  vision_tokenizer=vision_tokenizer,
68
  lang_model=language_model,
69
- initial_tokenizer_len = config.text_config.initial_tokenizer_len,
70
- pad_token_id = config.text_config.pad_token_id,
71
- image_aspect_ratio = config.vision_encoder_config.image_aspect_ratio,
72
  )
73
  # Initialize weights and apply final processing
74
  self.post_init()
75
-
76
  @torch.no_grad()
77
  def generate(
78
  self,
@@ -80,14 +1858,15 @@ class XGenMMModelForConditionalGeneration(PreTrainedModel):
80
  input_ids: Optional[torch.LongTensor] = None,
81
  attention_mask: Optional[torch.LongTensor] = None,
82
  **generate_kwargs,
83
- ) -> torch.LongTensor:
84
  self.vlm = self.vlm.eval()
85
  return self.vlm.generate(
86
- vision_x = pixel_values,
87
- lang_x = input_ids,
88
- attention_mask = attention_mask,
89
- **generate_kwargs)
90
-
 
91
  def update_special_tokens(self, tokenizer):
92
  tokenizer.add_special_tokens(
93
  {"additional_special_tokens": list(self.vlm.special_tokens.values())}
@@ -95,8 +1874,8 @@ class XGenMMModelForConditionalGeneration(PreTrainedModel):
95
  self.vlm.lang_model.config.vocab_size = len(tokenizer)
96
  self.vlm.set_special_token_ids(
97
  {
98
- v: tokenizer.convert_tokens_to_ids(v) for v in self.vlm.special_tokens.values()
 
99
  }
100
  )
101
  return tokenizer
102
-
 
1
+ import ast
2
+ import math
3
+ from einops import rearrange, repeat
4
+ from einops_exts import rearrange_many
5
+ from einops import rearrange
6
+ from PIL import Image
7
  import torch
8
+ from torch import einsum, nn
9
+
10
+
11
  from typing import List, Optional, Tuple, Union
12
+ import torch.nn.functional as F
13
+ from transformers.modeling_outputs import CausalLMOutputWithPast
14
+ from dataclasses import dataclass
15
+ from transformers import CLIPVisionModel
16
+ from transformers import PreTrainedModel, AutoModelForCausalLM, AutoModel
17
+ from transformers import PretrainedConfig, logging, CONFIG_MAPPING
18
+ from transformers.models.siglip.modeling_siglip import SiglipVisionTransformer
19
+
20
+
21
+ logger = logging.get_logger(__name__)
22
+
23
+
24
+ class XGenMMVisionEncoderConfig(PretrainedConfig):
25
+ model_type = "xgenmm_vision_encoder"
26
+
27
+ def __init__(self, model_name: str = "google/siglip-so400m-patch14-384", **kwargs):
28
+ self.model_name = model_name
29
+ super().__init__(**kwargs)
30
+
31
+
32
+ class XGenMMVisionTokenizerConfig(PretrainedConfig):
33
+ model_type = "xgenmm_vision_tokenizer"
34
+
35
+ def __init__(
36
+ self,
37
+ vis_feature_dim: int = 1152,
38
+ lang_embedding_dim: int = 3072,
39
+ num_vis_tokens: int = 128,
40
+ image_aspect_ratio: str = "none",
41
+ **kwargs,
42
+ ):
43
+ self.vis_feature_dim = vis_feature_dim
44
+ self.lang_embedding_dim = lang_embedding_dim
45
+ self.num_vis_tokens = num_vis_tokens
46
+ self.image_aspect_ratio = image_aspect_ratio
47
+ super().__init__(**kwargs)
48
+
49
+
50
+ class XGenMMConfig(PretrainedConfig):
51
+ model_type = "xgenmm"
52
+
53
+ def __init__(
54
+ self,
55
+ vision_encoder_config: dict = None,
56
+ vision_tokenizer_config: dict = None,
57
+ text_config: dict = None,
58
+ **kwargs,
59
+ ):
60
+
61
+ if vision_encoder_config is None:
62
+ vision_encoder_config = {
63
+ "image_aspect_ratio": "anyres",
64
+ "anyres_patch_sampling": True,
65
+ }
66
+ logger.info(
67
+ "vision_encoder_config is None. initializing the XGenMMVisionEncoderConfig with default values."
68
+ )
69
+
70
+ if vision_tokenizer_config is None:
71
+ vision_tokenizer_config = {}
72
+ logger.info(
73
+ "vision_tokenizer_config is None. Initializing the XGenMMVisionTokenizerConfig with default values."
74
+ )
75
+
76
+ if text_config is None:
77
+ text_config = {
78
+ "initial_tokenizer_len": 32012,
79
+ "pad_token_id": 32011,
80
+ "bos_token_id": 1,
81
+ "eos_token_id": 32000,
82
+ "vocab_size": 32064,
83
+ "hidden_size": 3072,
84
+ "intermediate_size": 8192,
85
+ "num_hidden_layers": 32,
86
+ "num_attention_heads": 32,
87
+ "num_key_value_heads": 32,
88
+ "resid_pdrop": 0.0,
89
+ "embd_pdrop": 0.0,
90
+ "attention_dropout": 0.0,
91
+ "hidden_act": "silu",
92
+ "max_position_embeddings": 4096,
93
+ "original_max_position_embeddings": 4096,
94
+ "initializer_range": 0.02,
95
+ "rms_norm_eps": 1e-05,
96
+ "use_cache": True,
97
+ "rope_theta": 10000.0,
98
+ "rope_scaling": None,
99
+ "sliding_window": 2047,
100
+ "return_dict": True,
101
+ "output_hidden_states": False,
102
+ "output_attentions": False,
103
+ "torchscript": False,
104
+ "torch_dtype": "bfloat16",
105
+ "use_bfloat16": False,
106
+ "tf_legacy_loss": False,
107
+ "pruned_heads": {},
108
+ "tie_word_embeddings": False,
109
+ "chunk_size_feed_forward": 0,
110
+ "is_encoder_decoder": False,
111
+ "is_decoder": False,
112
+ "cross_attention_hidden_size": None,
113
+ "add_cross_attention": False,
114
+ "tie_encoder_decoder": False,
115
+ "max_length": 20,
116
+ "min_length": 0,
117
+ "do_sample": False,
118
+ "early_stopping": False,
119
+ "num_beams": 1,
120
+ "num_beam_groups": 1,
121
+ "diversity_penalty": 0.0,
122
+ "temperature": 1.0,
123
+ "top_k": 50,
124
+ "top_p": 1.0,
125
+ "typical_p": 1.0,
126
+ "repetition_penalty": 1.0,
127
+ "length_penalty": 1.0,
128
+ "no_repeat_ngram_size": 0,
129
+ "encoder_no_repeat_ngram_size": 0,
130
+ "bad_words_ids": None,
131
+ "num_return_sequences": 1,
132
+ "output_scores": False,
133
+ "return_dict_in_generate": False,
134
+ "forced_bos_token_id": None,
135
+ "forced_eos_token_id": None,
136
+ "remove_invalid_values": False,
137
+ "exponential_decay_length_penalty": None,
138
+ "suppress_tokens": None,
139
+ "begin_suppress_tokens": None,
140
+ "finetuning_task": None,
141
+ "id2label": {0: "LABEL_0", 1: "LABEL_1"},
142
+ "label2id": {"LABEL_0": 0, "LABEL_1": 1},
143
+ "tokenizer_class": None,
144
+ "prefix": None,
145
+ "bos_token_id": 1,
146
+ "pad_token_id": 32000,
147
+ "eos_token_id": 32000,
148
+ "sep_token_id": None,
149
+ "decoder_start_token_id": None,
150
+ "task_specific_params": None,
151
+ "problem_type": None,
152
+ "model_type": "phi3",
153
+ }
154
+ logger.info(
155
+ "text_config is None. Initializing the text config with default values (`Phi3Config`)."
156
+ )
157
+
158
+ self.vision_encoder_config = XGenMMVisionEncoderConfig(**vision_encoder_config)
159
+
160
+ self.vision_tokenizer_config = XGenMMVisionTokenizerConfig(
161
+ **vision_tokenizer_config
162
+ )
163
+
164
+ text_model_type = (
165
+ text_config["model_type"] if "model_type" in text_config else "phi3"
166
+ )
167
+ self.text_config = CONFIG_MAPPING[text_model_type](**text_config)
168
+
169
+ for key in ["initial_tokenizer_len", "pad_token_id"]:
170
+ if key not in self.text_config.to_dict():
171
+ raise ValueError(f"The key `{key}` is missing in the text_config.")
172
+
173
+ super().__init__(**kwargs)
174
+
175
+ @classmethod
176
+ def from_vision_encoder_vision_tokenizer_text_configs(
177
+ cls,
178
+ vision_encoder_config: XGenMMVisionEncoderConfig,
179
+ vision_tokenizer_config: XGenMMVisionTokenizerConfig,
180
+ text_config: PretrainedConfig,
181
+ **kwargs,
182
+ ):
183
+
184
+ return cls(
185
+ vision_encoder_config=vision_encoder_config.to_dict(),
186
+ vision_tokenizer_config=vision_tokenizer_config.to_dict(),
187
+ text_config=text_config.to_dict(),
188
+ **kwargs,
189
+ )
190
+
191
+
192
+ def has_fn(model, fn_name):
193
+ """Check if model has a function fn_name"""
194
+ return callable(getattr(model, fn_name, None))
195
+
196
+
197
+ def exists(val):
198
+ return val is not None
199
+
200
+
201
+ def num_params(module, filter_to_trainable=False):
202
+ """Returns the number of parameters in the module, or optionally only the trainable parameters"""
203
+ if filter_to_trainable:
204
+ return sum(p.numel() for p in module.parameters() if p.requires_grad)
205
+ else:
206
+ return sum(p.numel() for p in module.parameters())
207
+
208
+
209
+ def hasattr_recursive(obj, att):
210
+ """
211
+ Check if obj has nested attribute
212
+ Example: hasattr_recursive(obj, 'a.b.c') is equivalent to hasattr(obj, 'a') and hasattr(obj.a, 'b') and hasattr(obj.a.b, 'c')
213
+ """
214
+ if att == "":
215
+ return True
216
+ i = att.find(".")
217
+ if i < 0:
218
+ return hasattr(obj, att)
219
+ else:
220
+ try:
221
+ return hasattr_recursive(getattr(obj, att[:i]), att[i + 1 :])
222
+ except:
223
+ return False
224
+
225
+
226
+ def getattr_recursive(obj, att):
227
+ """
228
+ Return nested attribute of obj
229
+ Example: getattr_recursive(obj, 'a.b.c') is equivalent to obj.a.b.c
230
+ """
231
+ if att == "":
232
+ return obj
233
+ i = att.find(".")
234
+ if i < 0:
235
+ return getattr(obj, att)
236
+ else:
237
+ return getattr_recursive(getattr(obj, att[:i]), att[i + 1 :])
238
+
239
+
240
+ def setattr_recursive(obj, att, val):
241
+ """
242
+ Set nested attribute of obj
243
+ Example: setattr_recursive(obj, 'a.b.c', val) is equivalent to obj.a.b.c = val
244
+ """
245
+ if "." in att:
246
+ obj = getattr_recursive(obj, ".".join(att.split(".")[:-1]))
247
+ setattr(obj, att.split(".")[-1], val)
248
+
249
+
250
+ def stack_with_padding(list_of_tensors, padding_value=0, padding_side="right"):
251
+ """
252
+ Stack a list of tensors with padding on one side
253
+ Args:
254
+ list_of_tensors (list[torch.Tensor]): List of tensors to stack
255
+ padding_value (int, optional): Value to pad with. Defaults to 0.
256
+ padding_side (str, optional): Side to pad on. Defaults to "right".
257
+ Returns:
258
+ torch.Tensor: Stacked tensors
259
+ """
260
+ max_tokens = max(tensor.size(0) for tensor in list_of_tensors)
261
+ padded_tensors = []
262
+ for tensor in list_of_tensors:
263
+ num_tokens = tensor.size(0)
264
+ if len(tensor.size()) == 1:
265
+ padding = torch.full(
266
+ (max_tokens - num_tokens,),
267
+ padding_value,
268
+ dtype=tensor.dtype,
269
+ device=tensor.device,
270
+ )
271
+ else:
272
+ padding = torch.full(
273
+ (max_tokens - num_tokens, tensor.size(1)),
274
+ padding_value,
275
+ dtype=tensor.dtype,
276
+ device=tensor.device,
277
+ )
278
+ padded_tensor = (
279
+ torch.cat((tensor, padding), dim=0)
280
+ if padding_side == "right"
281
+ else torch.cat((padding, tensor), dim=0)
282
+ )
283
+ padded_tensors.append(padded_tensor)
284
+ return torch.stack(padded_tensors)
285
+
286
+
287
+ def check_embedding_fns(lang_model):
288
+ """Checks for and attempts to set {get/set}_{input/output}_embeddings functions to the model"""
289
+ if not has_fn(lang_model, "get_input_embeddings"):
290
+ if hasattr_recursive(lang_model, "transformer.wte"): # MPT
291
+ lang_model.get_input_embeddings = lambda: lang_model.transformer.wte
292
+ elif hasattr_recursive(lang_model, "model.decoder.embed_tokens"): # OPT
293
+ lang_model.get_input_embeddings = lambda: lang_model.decoder.embed_tokens
294
+ else:
295
+ raise ValueError(
296
+ "We require the language encoder to have a get_input_embeddings method but we couldn't determine the name of the input embeddings attribute. Please supply this manually in factory.py."
297
+ )
298
+
299
+ if not has_fn(lang_model, "set_input_embeddings"):
300
+ if hasattr_recursive(lang_model, "transformer.wte"): # MPT
301
+ lang_model.set_input_embeddings = lambda x: setattr_recursive(
302
+ lang_model, "transformer.wte", x
303
+ )
304
+ elif hasattr_recursive(lang_model, "model.decoder.embed_tokens"): # OPT
305
+ lang_model.set_input_embeddings = lambda x: setattr_recursive(
306
+ lang_model, "model.decoder.embed_tokens", x
307
+ )
308
+ else:
309
+ raise ValueError(
310
+ "We require the language encoder to have a set_input_embeddings method but we couldn't determine the name of the input embeddings attribute. Please supply this manually in factory.py."
311
+ )
312
+
313
+ if not has_fn(lang_model, "get_output_embeddings"):
314
+ if hasattr_recursive(lang_model, "lm_head"):
315
+ lang_model.get_output_embeddings = lambda: lang_model.lm_head
316
+ else:
317
+ raise ValueError(
318
+ "We require the language encoder to have a get_output_embeddings method but we couldn't determine the name of the output embeddings attribute. Please supply this manually in factory.py."
319
+ )
320
+
321
+ if not has_fn(lang_model, "set_output_embeddings"):
322
+ if hasattr_recursive(lang_model, "lm_head"):
323
+ lang_model.set_output_embeddings = lambda x: setattr_recursive(
324
+ lang_model, "lm_head", x
325
+ )
326
+ else:
327
+ raise ValueError(
328
+ "We require the language encoder to have a set_output_embeddings method but we couldn't determine the name of the output embeddings attribute. Please supply this manually in factory.py."
329
+ )
330
+
331
+
332
+ def has_fn(model, fn_name):
333
+ """Check if model has a function fn_name"""
334
+ return callable(getattr(model, fn_name, None))
335
+
336
+
337
+ def unpad_image(tensor, original_size, keep_original_shape=False):
338
+ """
339
+ Unpads a PyTorch tensor of a padded and resized image.
340
+
341
+ Args:
342
+ tensor (torch.Tensor): The image tensor, assumed to be in CxHxW format.
343
+ original_size (tuple): The original size of the image (height, width).
344
+
345
+ Returns:
346
+ torch.Tensor: The unpadded image tensor.
347
+ """
348
+ original_width, original_height = original_size
349
+ current_height, current_width = tensor.shape[1:]
350
+
351
+ original_aspect_ratio = original_width / original_height
352
+ current_aspect_ratio = current_width / current_height
353
+
354
+ if original_aspect_ratio > current_aspect_ratio:
355
+ scale_factor = current_width / original_width
356
+ new_height = int(original_height * scale_factor)
357
+ padding = (current_height - new_height) // 2
358
+ if keep_original_shape:
359
+ attention_mask = torch.ones(
360
+ (current_height, current_width), device=tensor.device
361
+ )
362
+ attention_mask[:padding, :] = 0
363
+ attention_mask[current_height - padding :, :] = 0
364
+ return tensor, attention_mask
365
+ else:
366
+ unpadded_tensor = tensor[:, padding : current_height - padding, :]
367
+ return unpadded_tensor, None
368
+ else:
369
+ scale_factor = current_height / original_height
370
+ new_width = int(original_width * scale_factor)
371
+ padding = (current_width - new_width) // 2
372
+ if keep_original_shape:
373
+ attention_mask = torch.ones(
374
+ (current_height, current_width), device=tensor.device
375
+ )
376
+ attention_mask[:, :padding] = 0
377
+ attention_mask[:, current_width - padding :] = 0
378
+ return tensor, attention_mask
379
+ else:
380
+ unpadded_tensor = tensor[:, :, padding : current_width - padding]
381
+ return unpadded_tensor, None
382
+
383
+
384
+ def expand2square(pil_img, background_color):
385
+ width, height = pil_img.size
386
+ if width == height:
387
+ return pil_img
388
+ elif width > height:
389
+ result = Image.new(pil_img.mode, (width, width), background_color)
390
+ result.paste(pil_img, (0, (width - height) // 2))
391
+ return result
392
+ else:
393
+ result = Image.new(pil_img.mode, (height, height), background_color)
394
+ result.paste(pil_img, ((height - width) // 2, 0))
395
+ return result
396
+
397
+
398
+ class VisionTokenizer(nn.Module):
399
+ def __init__(self, dim_media, num_tokens_per_media):
400
+ super().__init__()
401
+ self.dim_media = dim_media
402
+ self.num_tokens_per_media = num_tokens_per_media
403
+
404
+
405
+ class PerceiverAttention(nn.Module):
406
+ def __init__(self, *, dim, dim_head=64, heads=8):
407
+ super().__init__()
408
+ self.scale = dim_head**-0.5
409
+ self.heads = heads
410
+ inner_dim = dim_head * heads
411
+
412
+ self.norm_media = nn.LayerNorm(dim)
413
+ self.norm_latents = nn.LayerNorm(dim)
414
+
415
+ self.to_q = nn.Linear(dim, inner_dim, bias=False)
416
+ self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
417
+ self.to_out = nn.Linear(inner_dim, dim, bias=False)
418
+
419
+ def forward(self, x, latents, vision_attn_masks=None):
420
+ """
421
+ Args:
422
+ x (torch.Tensor): image features
423
+ shape (b, T, n1, D)
424
+ latent (torch.Tensor): latent features
425
+ shape (b, T, n2, D)
426
+ """
427
+ x = self.norm_media(x)
428
+ latents = self.norm_latents(latents)
429
+
430
+ h = self.heads
431
+
432
+ q = self.to_q(latents)
433
+ kv_input = torch.cat(
434
+ (x, latents), dim=-2
435
+ ) # TODO: Change the shape of vision attention mask according to this.
436
+ if vision_attn_masks is not None:
437
+ vision_attn_masks = torch.cat(
438
+ (
439
+ vision_attn_masks,
440
+ torch.ones(
441
+ (latents.shape[0], latents.shape[-2]),
442
+ dtype=latents.dtype,
443
+ device=latents.device,
444
+ ),
445
+ ),
446
+ dim=-1,
447
+ )
448
+ k, v = self.to_kv(kv_input).chunk(2, dim=-1)
449
+ q, k, v = rearrange_many((q, k, v), "b t n (h d) -> b h t n d", h=h)
450
+ q = q * self.scale
451
+
452
+ # attention
453
+ sim = einsum("... i d, ... j d -> ... i j", q, k)
454
+ # Apply vision attention mask here.
455
+ # Reference: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html#torch.nn.functional.scaled_dot_product_attention
456
+ if vision_attn_masks is not None:
457
+ attn_bias = torch.zeros(
458
+ (q.size(0), 1, 1, q.size(-2), k.size(-2)),
459
+ dtype=q.dtype,
460
+ device=q.device,
461
+ )
462
+ vision_attn_masks = repeat(
463
+ vision_attn_masks, "b n -> b 1 1 l n", l=q.size(-2)
464
+ )
465
+ attn_bias.masked_fill_(vision_attn_masks.logical_not(), float("-inf"))
466
+ sim += attn_bias
467
+
468
+ sim = sim - sim.amax(dim=-1, keepdim=True).detach()
469
+ attn = sim.softmax(dim=-1)
470
+
471
+ out = einsum("... i j, ... j d -> ... i d", attn, v)
472
+ out = rearrange(out, "b h t n d -> b t n (h d)", h=h)
473
+ return self.to_out(out)
474
+
475
+
476
+ def FeedForward(dim, mult=4):
477
+ inner_dim = int(dim * mult)
478
+ return nn.Sequential(
479
+ nn.LayerNorm(dim),
480
+ nn.Linear(dim, inner_dim, bias=False),
481
+ nn.GELU(),
482
+ nn.Linear(inner_dim, dim, bias=False),
483
+ )
484
+
485
+
486
+ class PerceiverResampler(VisionTokenizer):
487
+ def __init__(
488
+ self,
489
+ *,
490
+ dim,
491
+ dim_inner=None,
492
+ depth=6,
493
+ dim_head=96,
494
+ heads=16,
495
+ num_latents=128,
496
+ max_num_media=None,
497
+ max_num_frames=None,
498
+ ff_mult=4,
499
+ ):
500
+ """
501
+ Perceiver module which takes in image features and outputs image tokens.
502
+ Args:
503
+ dim (int): dimension of the incoming image features
504
+ dim_inner (int, optional): final dimension to project the incoming image features to;
505
+ also the final dimension of the outputted features. If None, no projection is used, and dim_inner = dim.
506
+ depth (int, optional): number of layers. Defaults to 6.
507
+ dim_head (int, optional): dimension of each head. Defaults to 64.
508
+ heads (int, optional): number of heads. Defaults to 8.
509
+ num_latents (int, optional): number of latent tokens to use in the Perceiver;
510
+ also corresponds to number of tokens per sequence to output. Defaults to 64.
511
+ max_num_media (int, optional): maximum number of media per sequence to input into the Perceiver
512
+ and keep positional embeddings for. If None, no positional embeddings are used.
513
+ max_num_frames (int, optional): maximum number of frames to input into the Perceiver
514
+ and keep positional embeddings for. If None, no positional embeddings are used.
515
+ ff_mult (int, optional): dimension multiplier for the feedforward network. Defaults to 4.
516
+ """
517
+ if dim_inner is not None:
518
+ projection = nn.Linear(dim, dim_inner)
519
+ else:
520
+ projection = None
521
+ dim_inner = dim
522
+ super().__init__(dim_media=dim, num_tokens_per_media=num_latents)
523
+ self.projection = projection
524
+ self.latents = nn.Parameter(torch.randn(num_latents, dim))
525
+
526
+ # positional embeddings
527
+ self.frame_embs = (
528
+ nn.Parameter(torch.randn(max_num_frames, dim))
529
+ if exists(max_num_frames)
530
+ else None
531
+ )
532
+ self.media_time_embs = (
533
+ nn.Parameter(torch.randn(max_num_media, 1, dim))
534
+ if exists(max_num_media)
535
+ else None
536
+ )
537
+
538
+ self.layers = nn.ModuleList([])
539
+ for _ in range(depth):
540
+ self.layers.append(
541
+ nn.ModuleList(
542
+ [
543
+ PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
544
+ FeedForward(dim=dim, mult=ff_mult),
545
+ ]
546
+ )
547
+ )
548
+
549
+ self.norm = nn.LayerNorm(dim)
550
+
551
+ def forward(self, x, vision_attn_masks=None):
552
+ """
553
+ Args:
554
+ x (torch.Tensor): image features
555
+ shape (b, T, F, v, D)
556
+ vision_attn_masks (torch.Tensor): attention masks for padded visiont tokens (i.e., x)
557
+ shape (b, v)
558
+ Returns:
559
+ shape (b, T, n, D) where n is self.num_latents
560
+ """
561
+ b, T, F, v = x.shape[:4]
562
+
563
+ # frame and media time embeddings
564
+ if exists(self.frame_embs):
565
+ frame_embs = repeat(self.frame_embs[:F], "F d -> b T F v d", b=b, T=T, v=v)
566
+ x = x + frame_embs
567
+ x = rearrange(
568
+ x, "b T F v d -> b T (F v) d"
569
+ ) # flatten the frame and spatial dimensions
570
+ if exists(self.media_time_embs):
571
+ x = x + self.media_time_embs[:T]
572
+
573
+ # blocks
574
+ latents = self.latents
575
+ latents = repeat(latents, "n d -> b T n d", b=b, T=T)
576
+ for attn, ff in self.layers:
577
+ latents = attn(x, latents, vision_attn_masks) + latents
578
+ latents = ff(latents) + latents
579
+
580
+ if exists(self.projection):
581
+ return self.projection(self.norm(latents))
582
+ else:
583
+ return self.norm(latents)
584
+
585
+
586
+ class DecoupledEmbedding(nn.Embedding):
587
+ # Derived from https://pytorch.org/docs/stable/_modules/torch/nn/modules/sparse.html#Embedding
588
+ """
589
+ Implements a decoupling of parameters to allow freezing (or not) a subset of the embeddings. In practise, the
590
+ regular `weight` can be trained or frozen (i.e. `partially_freeze=True`), and if `num_additional_embeddings` > 0,
591
+ then it will create `num_additional_embeddings` additional parameters that are always trained. If
592
+ `num_additional_embeddings=0`, then the module defaults back to the regular behavior of `nn.Embedding`.
593
+ """
594
+
595
+ def __init__(
596
+ self,
597
+ max_original_id: int,
598
+ num_additional_embeddings: int = 0,
599
+ _weight: torch.Tensor = None,
600
+ num_original_embeddings: int = None,
601
+ embedding_dim: int = None,
602
+ partially_freeze=True,
603
+ device=None,
604
+ dtype=None,
605
+ pad_token_id=None,
606
+ ) -> None:
607
+ """
608
+ Args:
609
+ max_original_id (`int`):
610
+ The largest token id that should be embedded using the regular embedding (regular `weight`).
611
+ This is usually len(tokenizer) - 1 before additional tokens are added.
612
+ Note that this may not equal self.weight.shape[0]
613
+ num_additional_embeddings (`int`):
614
+ Number of additional tokens to initialize an Embedding matrix for (`additional_weight`).
615
+ _weight (`torch.Tensor`, *optional*, defaults to `None`): The regular weight tensor.
616
+ If provided, this sets the `num_original_embeddings` and `embedding_dim` parameters.
617
+ num_original_embeddings (`int`):
618
+ self.weight.shape[0]
619
+ embedding_dim (`int`):
620
+ The size of each embedding vector
621
+ partially_freeze: (`bool`, *optional*, defaults to `True`):
622
+ If `True`, the regular `weight` will be frozen. `additional_weight` is never frozen.
623
+ padding_idx (`int`, *optional*):
624
+ The padding index (needs to be less than num_embeddings)
625
+
626
+ Note: there are a lot of other parameters to initialize a standard `nn.Embedding` such as `padding_idx`,
627
+ `max_norm` or `norm_type`. We are not supporting these.
628
+ """
629
+ # validate args
630
+ if pad_token_id is not None and pad_token_id > max_original_id:
631
+ raise ValueError(
632
+ f"pad_token_id must be <= max_original_id. Got {pad_token_id} and {max_original_id}."
633
+ + "If the original tokenizer does not have a pad_token_id, use pad_token_id=None."
634
+ )
635
+ if _weight is not None:
636
+ assert (num_original_embeddings is None) or (
637
+ _weight.shape[0] == num_original_embeddings
638
+ ), f"num_original_embeddings={num_original_embeddings} but _weight.shape[0]={_weight.shape[0]}"
639
+ assert (embedding_dim is None) or (
640
+ _weight.shape[1] == embedding_dim
641
+ ), f"embedding_dim={embedding_dim} but _weight.shape[1]={_weight.shape[1]}"
642
+ num_original_embeddings = _weight.shape[0]
643
+ embedding_dim = _weight.shape[1]
644
+ else:
645
+ assert (
646
+ num_original_embeddings is not None
647
+ ), "num_original_embeddings must be provided if _weight is not provided"
648
+ assert (
649
+ embedding_dim is not None
650
+ ), "embedding_dim must be provided if _weight is not provided"
651
+
652
+ super().__init__(
653
+ num_embeddings=num_original_embeddings,
654
+ embedding_dim=embedding_dim,
655
+ device=device,
656
+ dtype=dtype,
657
+ padding_idx=pad_token_id,
658
+ _weight=_weight,
659
+ )
660
+ self.max_original_id = max_original_id
661
+ self.padding_idx = pad_token_id
662
+ self.num_additional_embeddings = num_additional_embeddings
663
+ if self.num_additional_embeddings > 0:
664
+ self.additional_embedding = nn.Embedding(
665
+ num_embeddings=self.num_additional_embeddings,
666
+ embedding_dim=embedding_dim,
667
+ device=device,
668
+ dtype=dtype,
669
+ )
670
+ self.set_requires_grad(
671
+ require_regular_grad=not partially_freeze, require_additional_grad=True
672
+ )
673
+
674
+ def set_requires_grad(self, require_regular_grad, require_additional_grad):
675
+ """
676
+ Helper function to separately set the requires_grad flag for the regular weight and the additional weight.
677
+ """
678
+ self.weight.requires_grad_(require_regular_grad)
679
+ self.additional_embedding.requires_grad_(require_additional_grad)
680
+
681
+ def forward(self, input_ids):
682
+ """
683
+ we have 2 embeddings, with different indices - one pretrained self.weight and another
684
+ self.additional_embedding.weight that is being trained.
685
+
686
+ in order to make a lookup of the input ids, we:
687
+ 1. find out the indices of the entries belonging to the 2nd embedding
688
+ 2. extract those values while subtracting the size of the first embedding (num_embeddings), since the 2nd
689
+ embedding starts from 0 and not num_embeddings
690
+ 3. perform the 2nd embedding lookup
691
+ 4. now we handle the 1st embedding, we overwrite indices belonging to the 2nd embedding with a padding index
692
+ 5. perform the 1st embedding lookup
693
+ 6. now we overwrite the values in the 1st embedding lookup with the values of the 2nd embedding lookup
694
+
695
+ note: for the 1st embedding lookup we could have looked up only the low indices and not do the padding, but
696
+ then we have to create a new tensor and populate it with 2 tensors that are spread out across various indices -
697
+ i.e. not a simple concat - I haven't benchmarked the complex case if it's any faster, given that seqlens are
698
+ usually relatively short it's probably not faster or if faster not by much - but might be a good idea to
699
+ measure.
700
+
701
+ """
702
+ if self.num_additional_embeddings == 0:
703
+ return F.embedding(input_ids, self.weight)
704
+
705
+ # Clone so that we don't modify the original input_ids later on
706
+ input_ids = input_ids.clone()
707
+ additional_vocab_indices = torch.where(input_ids > self.max_original_id)
708
+ input_ids_additional_vocab = input_ids[additional_vocab_indices]
709
+ additional_embeddings = self.additional_embedding(
710
+ input_ids_additional_vocab - self.max_original_id - 1
711
+ )
712
+
713
+ # for successful lookup replace input_ids with 0, the results of these will be discarded anyway
714
+ input_ids[additional_vocab_indices] = 0
715
+ full_vector = F.embedding(input_ids, self.weight)
716
+
717
+ # overwrite the records with high indices
718
+ full_vector[additional_vocab_indices] = additional_embeddings
719
+
720
+ return full_vector
721
+
722
+ def extra_repr(self) -> str:
723
+ return "num_original_embeddings={}, num_additional_embeddings={}, embedding_dim={}, partially_freeze={}".format(
724
+ self.max_original_id + 1,
725
+ self.num_additional_embeddings,
726
+ self.embedding_dim,
727
+ (not self.weight.requires_grad),
728
+ )
729
+
730
+
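A small usage sketch of the class above, under the assumption that `DecoupledEmbedding` is importable from the merged modeling_xgenmm.py; the toy sizes are arbitrary. Ids up to `max_original_id` are looked up in the (optionally frozen) regular weight, larger ids in the always-trainable `additional_embedding`:

```python
import torch
from modeling_xgenmm import DecoupledEmbedding  # assumed local import of the merged file

emb = DecoupledEmbedding(
    max_original_id=9,              # ids 0..9 use the regular (frozen) weight
    num_additional_embeddings=2,    # ids 10..11 use the trainable additional_embedding
    num_original_embeddings=10,
    embedding_dim=4,
)
ids = torch.tensor([[1, 5, 10, 11]])
out = emb(ids)                      # shape (1, 4, 4); rows for ids 10 and 11 come from additional_embedding
```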
731
+ class DecoupledLinear(nn.Linear):
732
+ # Derived from https://pytorch.org/docs/stable/_modules/torch/nn/modules/linear.html#Linear
733
+ """
734
+ Implements a decoupling of parameters to allow freezing (or not) a subset of the parameters. In practice, the
735
+ regular `weight` can be trained or frozen (i.e. `partially_freeze=True`), and if `additional_out_features` > 0,
736
+ then it will create `additional_out_features * in_features` additional parameters that are always trained. If
737
+ `additional_out_features=0`, then the module defaults back to the regular behavior of `nn.Linear`.
738
+ """
739
+
740
+ def __init__(
741
+ self,
742
+ max_original_id: int,
743
+ additional_out_features: int = 0,
744
+ _weight: torch.Tensor = None,
745
+ _bias: torch.Tensor = None,
746
+ in_features: int = None,
747
+ original_out_features: int = None,
748
+ bias: bool = True,
749
+ partially_freeze: bool = True,
750
+ device=None,
751
+ dtype=None,
752
+ ) -> None:
753
+ """
754
+ Args:
755
+ max_original_id (`int`): The largest token id that should be extracted from the regular weight.
756
+ This is usually len(tokenizer) - 1 before additional tokens are added.
757
+ Note that this may not equal original_out_features - 1
758
+ _weight: torch.Tensor, *optional*, defaults to `None`. The regular weight tensor.
759
+ If provided, this sets the `in_features` and `original_out_features` parameters.
760
+ _bias: torch.Tensor, *optional*, defaults to `None`. The regular bias tensor.
761
+ in_features: int. Input hidden size.
762
+ original_out_features: int. Original out_features of the language model's get_output_embeddings() function.
763
+ additional_out_features: int. Number of additional trainable dimensions.
764
+ bias: bool. Whether to include a bias term.
765
+ partially_freeze: bool, *optional*, defaults to `True`. If `True`, the regular `weight` will be frozen.
766
+ """
767
+ # argument validation
768
+ if _weight is not None:
769
+ assert (_weight.shape[0] == original_out_features) or (
770
+ original_out_features is None
771
+ ), f"original_out_features={original_out_features} but _weight.shape[0]={_weight.shape[0]}"
772
+ assert (_weight.shape[1] == in_features) or (
773
+ in_features is None
774
+ ), f"in_features={in_features} but _weight.shape[1]={_weight.shape[1]}"
775
+ in_features = _weight.shape[1]
776
+ original_out_features = _weight.shape[0]
777
+ else:
778
+ assert (
779
+ in_features is not None
780
+ ), "in_features must be provided if _weight is not provided"
781
+ assert (
782
+ original_out_features is not None
783
+ ), "original_out_features must be provided if _weight is not provided"
784
+
785
+ if _bias is not None:
786
+ assert bias is True, "bias must be True if _bias is provided"
787
+
788
+ # initialize original linear
789
+ super().__init__(in_features, original_out_features, bias, device, dtype)
790
+
791
+ # set weight and bias manually
792
+ if _weight is not None:
793
+ self.weight = nn.Parameter(_weight)
794
+ if _bias is not None:
795
+ self.bias = nn.Parameter(_bias)
796
+
797
+ self.in_features = in_features
798
+ self.original_out_features = original_out_features
799
+ self.max_original_id = max_original_id
800
+
801
+ # initialize additional linear
802
+ self.additional_out_features = additional_out_features
803
+ self.has_bias = bias
804
+ if additional_out_features > 0:
805
+ self.additional_fc = nn.Linear(
806
+ in_features=in_features,
807
+ out_features=additional_out_features,
808
+ bias=self.has_bias,
809
+ device=device,
810
+ dtype=dtype,
811
+ )
812
+ self.set_requires_grad(
813
+ require_regular_grad=not partially_freeze, require_additional_grad=True
814
+ )
815
+
816
+ def set_requires_grad(self, require_regular_grad, require_additional_grad):
817
+ """
818
+ Helper function to separately set the requires_grad flag for the regular weight and the additional weight.
819
+ """
820
+ self.weight.requires_grad_(require_regular_grad)
821
+ if self.has_bias:
822
+ self.bias.requires_grad_(require_regular_grad)
823
+ self.additional_fc.requires_grad_(require_additional_grad)
824
+
825
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
826
+ output = F.linear(input, self.weight, self.bias)
827
+ output = output[..., : self.max_original_id + 1]
828
+
829
+ if self.additional_out_features > 0:
830
+ additional_features = F.linear(
831
+ input, self.additional_fc.weight, self.additional_fc.bias
832
+ )
833
+ output = torch.cat((output, additional_features), -1)
834
+ return output
835
+
836
+ def extra_repr(self) -> str:
837
+ """Overwriting `nn.Linear.extra_repr` to include new parameters."""
838
+ return "in_features={}, out_features={}, additional_out_features={}, bias={}, partially_freeze={}".format(
839
+ self.in_features,
840
+ self.max_original_id + 1,
841
+ self.additional_out_features,
842
+ self.bias is not None,
843
+ (not self.weight.requires_grad or not self.bias.requires_grad),
844
+ )
845
+
846
+
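Likewise, a hedged sketch of `DecoupledLinear` (same import assumption as above): the regular head's output is truncated to `max_original_id + 1` columns and the always-trainable `additional_fc` columns are appended:

```python
import torch
from modeling_xgenmm import DecoupledLinear  # assumed local import of the merged file

lin = DecoupledLinear(
    max_original_id=9,             # keep logits for ids 0..9 from the regular (frozen) head
    additional_out_features=2,     # two extra trainable output columns appended at the end
    in_features=4,
    original_out_features=12,      # the original head may be wider than max_original_id + 1
)
h = torch.randn(3, 4)
logits = lin(h)                    # shape (3, 12): 10 truncated original columns + 2 additional ones
```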
847
+ class VLM(nn.Module):
848
+ """
849
+ Generic vision-language model (VLM) class.
850
+ A VLM consists of four components:
851
+ 1. A vision encoder that extracts features from pixels, e.g. CLIP
852
+ input: (B, T_img, F, C, H, W)
853
+ output: (B, T_img, F, v, d)
854
+ 2. A vision tokenizer that converts these features to visual token-like embeddings, e.g. Perceiver, or a linear projection head
855
+ input: (B, T_img, F, v, d)
856
+ output: (B, T_img, n, d)
857
+ 3. A fusion method that allows the language model to attend to these tokens, e.g. cross-attention, or placing the tokens directly in the language model's input sequence
858
+ 4. A language model
859
+ """
860
+
861
+ def __init__(
862
+ self,
863
+ vision_encoder: nn.Module,
864
+ vision_tokenizer: nn.Module,
865
+ lang_model: nn.Module,
866
+ initial_tokenizer_len: int,
867
+ pad_token_id: int,
868
+ gradient_checkpointing: bool = False,
869
+ ):
870
+ """
871
+ Args:
872
+ vision_encoder (nn.Module): e.g. CLIP
873
+ vision_tokenizer (nn.Module): e.g. PerceiverResampler
874
+ lang_model (nn.Module): e.g. MPT
875
+ initial_tokenizer_len (int): size of the original tokenizer vocab
876
+ pad_token_id (int): id of the pad token
877
+ gradient_checkpointing (bool, optional): Whether to use gradient checkpointing. Defaults to False.
878
+ """
879
+ super().__init__()
880
+
881
+ # save dimension information
882
+ self.lang_embedding_dim = lang_model.get_input_embeddings().weight.shape[1]
883
+ if hasattr(lang_model.config, "d_model"):
884
+ self.lang_hidden_dim = lang_model.config.d_model # mpt uses d_model
885
+ else:
886
+ self.lang_hidden_dim = lang_model.config.hidden_size
887
+ self.vis_embedding_dim = vision_tokenizer.dim_media
888
+ self.num_tokens_per_vis = vision_tokenizer.num_tokens_per_media
889
+
890
+ # core components
891
+ self.vision_encoder = vision_encoder
892
+ self.vision_tokenizer = vision_tokenizer
893
+ self.lang_model = lang_model
894
+
895
+ # lm embeddings
896
+ self.pad_token_id = pad_token_id
897
+ self.initial_tokenizer_len = initial_tokenizer_len
898
+ input_embeds = DecoupledEmbedding(
899
+ max_original_id=initial_tokenizer_len - 1,
900
+ num_additional_embeddings=len(self.special_tokens),
901
+ _weight=self.lang_model.get_input_embeddings().weight,
902
+ pad_token_id=self.pad_token_id,
903
+ )
904
+ if hasattr(input_embeds, "additional_embedding"):
905
+ input_embeds.additional_embedding.weight.data.normal_(
906
+ mean=0.0,
907
+ std=(
908
+ self.lang_model.config.initializer_range
909
+ if hasattr(self.lang_model.config, "initializer_range")
910
+ else 0.02
911
+ ),
912
+ )
913
+ self.lang_model.set_input_embeddings(input_embeds)
914
+
915
+ out_embeds = DecoupledLinear(
916
+ max_original_id=initial_tokenizer_len - 1,
917
+ additional_out_features=len(self.special_tokens),
918
+ _weight=self.lang_model.get_output_embeddings().weight,
919
+ _bias=(
920
+ self.lang_model.get_output_embeddings().bias
921
+ if hasattr(self.lang_model.get_output_embeddings(), "bias")
922
+ else None
923
+ ),
924
+ )
925
+ if hasattr(out_embeds, "additional_fc"):
926
+ out_embeds.additional_fc.weight.data.normal_(
927
+ mean=0.0,
928
+ std=(
929
+ self.lang_model.config.initializer_range
930
+ if hasattr(self.lang_model.config, "initializer_range")
931
+ else 0.02
932
+ ),
933
+ )
934
+ self.lang_model.set_output_embeddings(out_embeds)
935
+
936
+ # gradient checkpointing
937
+ self.vision_tokenizer._use_gradient_checkpointing = gradient_checkpointing
938
+
939
+ def forward(
940
+ self,
941
+ vision_x: Optional[torch.Tensor],
942
+ lang_x: torch.Tensor,
943
+ attention_mask: Optional[torch.Tensor] = None,
944
+ labels: Optional[torch.Tensor] = None,
945
+ past_key_values: Optional[
946
+ List[Union[torch.Tensor, Tuple[torch.Tensor]]]
947
+ ] = None,
948
+ past_media_locations: Optional[torch.Tensor] = None,
949
+ past_vision_tokens: Optional[torch.Tensor] = None,
950
+ use_cache: Optional[bool] = False,
951
+ **kwargs,
952
+ ):
953
+ """
954
+ Args:
955
+ vision_x: Vision input
956
+ shape (B, T_img, F, C, H, W) with F=1
957
+ only F = 1 is supported (single-frame videos)
958
+ if T_img > the number of media tokens in the corresponding input_ids (lang_x),
959
+ only the first (number of media tokens in lang_x) images are used
960
+ lang_x: Language input ids, with media tokens denoting where
961
+ visual media should be inserted.
962
+ shape (B, T_txt)
963
+ attention_mask: Attention mask. Defaults to None.
964
+ labels: Labels. Defaults to None.
965
+ shape (B, T_txt)
966
+ past_key_values (Tuple[torch.Tensor]], optional): Past key value pairs for each of the T_txt previous tokens in the language model. Defaults to None.
967
+ list of length = number of decoder layers in the LM
968
+ exact implementation depends on LM, see Hugging Face docs
969
+ past_media_locations (torch.Tensor, optional): boolean mask denoting which of the previous T_txt tokens were media tokens. Defaults to None.
970
+ shape (B, T_txt)
971
+ past_vision_tokens (torch.Tensor, optional): Previous vision tokens. Defaults to None.
972
+ use_cache (Optional[bool], optional): Whether to use cache. Defaults to False.
973
+ If True, includes key_values, media_locations, and vision_tokens in the output.
974
+ """
975
+ assert not (past_vision_tokens is None) ^ (
976
+ past_media_locations is None
977
+ ), "past_vision_tokens and past_media_locations must both be None or both be not None"
978
+
979
+ # convert pixels to vision tokens
980
+ if vision_x is not None:
981
+ vision_features = self._encode_vision_x(vision_x=vision_x)
982
+ vision_tokens = self.vision_tokenizer(vision_features)
983
+ else:
984
+ vision_tokens = None
985
+
986
+ # fuse the vision and language tokens
987
+ new_inputs = self._prepare_inputs_for_forward(
988
+ vision_tokens=vision_tokens,
989
+ lang_x=lang_x,
990
+ attention_mask=attention_mask,
991
+ labels=labels,
992
+ past_key_values=past_key_values,
993
+ past_media_locations=past_media_locations,
994
+ padding_side="right",
995
+ past_vision_tokens=past_vision_tokens,
996
+ )
997
+ output = self.lang_model(
998
+ **new_inputs,
999
+ use_cache=use_cache,
1000
+ past_key_values=past_key_values,
1001
+ **kwargs,
1002
+ )
1003
+
1004
+ # postprocessing may be needed, e.g. to remove extra tokens from logits that were inserted into the language stream
1005
+ # or to add the past_vision_tokens and past_media_locations to the output
1006
+ output = self._postprocess_outputs_from_forward(
1007
+ output=output,
1008
+ lang_x=lang_x,
1009
+ vision_tokens=vision_tokens,
1010
+ use_cache=use_cache,
1011
+ past_vision_tokens=past_vision_tokens,
1012
+ past_media_locations=past_media_locations,
1013
+ )
1014
+
1015
+ # postforward hooks
1016
+ self._post_forward_hook()
1017
+ return output
1018
+
1019
+ def _encode_vision_x_anyres(self, samples, device):
1020
+ assert self.anyres_grids is not None
1021
+ image_raw = samples[
1022
+ "image"
1023
+ ] # list of patch lists, each of shape [1, N_patch, C, H, W]
1024
+ image_sizes = samples["image_size"]
1025
+
1026
+ # image_raw can be a list of lists of patches when a sample has multiple images.
1027
+ if isinstance(image_raw[0], list):
1028
+ images = [x.squeeze(0) for sample_img in image_raw for x in sample_img]
1029
+ image_sizes = [s for sample_sizes in image_sizes for s in sample_sizes]
1030
+ else:
1031
+ # assert isinstance(image_raw[0], torch.Tensor), f"Unknown image type: {image_raw[0]}"
1032
+ # concatenate the list of patches into one big batch for any-res encoding.
1033
+ images = [x.squeeze(0) for x in image_raw] # [N_patch, C, H, W]
1034
+ image = torch.cat(images, dim=0) # [\sum{B}{N_patch_i}, C, H, W]
1035
+ image = image.to(device)
1036
+
1037
+ with torch.no_grad():
1038
+ if self.vision_encoder.__class__.__name__ == "TimmModel":
1039
+ image_embeds = self.vision_encoder.trunk.forward_features(image)
1040
+ elif self.vision_encoder.__class__.__name__ in [
1041
+ "CLIPVisionModel",
1042
+ "SiglipVisionTransformer",
1043
+ ]:
1044
+ image_embeds = self.vision_encoder(image).last_hidden_state
1045
+ else:
1046
+ image_embeds = self.vision_encoder(image)[1] # OpenCLIP returns tuples
1047
+
1048
+ if isinstance(self.vision_encoder, CLIPVisionModel) or isinstance(
1049
+ self.vision_encoder, SiglipVisionTransformer
1050
+ ):
1051
+ base_img_size = self.vision_encoder.config.image_size
1052
+ else:
1053
+ base_img_size = self.vision_encoder.image_size[0]
1054
+
1055
+ if self.vision_encoder.__class__.__name__ == "TimmModel":
1056
+ grid_size = self.vision_encoder.trunk.patch_embed.grid_size
1057
+ elif self.vision_encoder.__class__.__name__ in [
1058
+ "CLIPVisionModel",
1059
+ "SiglipVisionTransformer",
1060
+ ]:
1061
+ grid_size_base = (
1062
+ self.vision_encoder.config.image_size
1063
+ // self.vision_encoder.config.patch_size
1064
+ )
1065
+ grid_size = (grid_size_base, grid_size_base)
1066
+ else:
1067
+ grid_size = self.vision_encoder.grid_size
1068
+ height, width = grid_size
1069
+
1070
+ if not image_embeds.shape[1] == height * width:
1071
+ assert (
1072
+ image_embeds.shape[1] == height * width + 1
1073
+ ) # For vision encoders that have a [CLS] token.
1074
+ image_embeds = image_embeds[:, 1:, :] # Drop the CLS token for each patch.
1075
+ n_vis_token_per_patch = image_embeds.shape[1]
1076
+
1077
+ # Split encoded patches and merge patch features
1078
+ # 1. Get the raw sizes from samples, and split the image embeds [\sum_{B}(N_patch_i), N_tok(16*16), C]
1079
+ split_sizes = [image.shape[0] for image in images]
1080
+ image_embeds = torch.split(image_embeds, split_sizes, dim=0)
1081
+ # 2. For each image (consist of a list of patches), merge the patches spatially (of shape [C, n_patch_height, n_patch_width])
1082
+ new_image_embeds = []
1083
+ patch_attn_masks = []
1084
+ max_n_img_token = -1
1085
+ for idx, patch_embeds in enumerate(image_embeds):
1086
+ if patch_embeds.shape[0] > 1:
1087
+ # 3. Flatten the patch features and get [C, n_patch_height * (n_patch_width+1)]
1088
+ base_patch_embeds = patch_embeds[
1089
+ 0
1090
+ ] # TODO: prepend the CLS token for the base patch embeds (of the resized entire image).
1091
+ patch_embeds = patch_embeds[1:]
1092
+
1093
+ assert height * width == base_patch_embeds.shape[0]
1094
+
1095
+ num_patch_width, num_patch_height = get_anyres_image_grid_shape(
1096
+ image_sizes[idx], self.anyres_grids, base_img_size
1097
+ ) # Hardcoded grid_pinpoints.
1098
+ patch_embeds = patch_embeds.view(
1099
+ num_patch_height, num_patch_width, height, width, -1
1100
+ )
1101
+
1102
+ patch_embeds = patch_embeds.permute(4, 0, 2, 1, 3).contiguous()
1103
+ patch_embeds = patch_embeds.flatten(1, 2).flatten(2, 3)
1104
+ patch_embeds, patch_attn_mask = unpad_image(
1105
+ patch_embeds, image_sizes[idx], self.anyres_patch_sampling
1106
+ )
1107
+ if hasattr(self, "image_newline"):
1108
+ patch_embeds = torch.cat(
1109
+ (
1110
+ patch_embeds,
1111
+ self.image_newline[:, None, None].expand(
1112
+ *patch_embeds.shape[:-1], 1
1113
+ ),
1114
+ ),
1115
+ dim=-1,
1116
+ )
1117
+ if self.anyres_patch_sampling:
1118
+ patch_embeds = patch_embeds.view(
1119
+ -1, num_patch_height, num_patch_width, height * width
1120
+ )
1121
+ patch_embeds = patch_embeds.flatten(1, 2).permute(1, 2, 0)
1122
+ assert patch_attn_mask is not None
1123
+ patch_attn_mask = patch_attn_mask.view(
1124
+ num_patch_height, num_patch_width, height * width
1125
+ )
1126
+ patch_attn_mask = patch_attn_mask.flatten(0, 1)
1127
+ patch_embeds = torch.cat(
1128
+ (base_patch_embeds.unsqueeze(0), patch_embeds), dim=0
1129
+ )
1130
+ patch_attn_mask = torch.cat(
1131
+ (
1132
+ torch.ones(
1133
+ n_vis_token_per_patch, device=patch_embeds.device
1134
+ ).unsqueeze(0),
1135
+ patch_attn_mask,
1136
+ ),
1137
+ dim=0,
1138
+ )
1139
+ else:
1140
+ patch_embeds = patch_embeds.flatten(1, 2).transpose(0, 1)
1141
+ patch_embeds = torch.cat((base_patch_embeds, patch_embeds), dim=0)
1142
+ else:
1143
+ patch_embeds = (
1144
+ patch_embeds[0].unsqueeze(0)
1145
+ if self.anyres_patch_sampling
1146
+ else patch_embeds[0]
1147
+ )
1148
+ patch_attn_mask = (
1149
+ torch.ones(
1150
+ n_vis_token_per_patch, device=patch_embeds.device
1151
+ ).unsqueeze(0)
1152
+ if self.anyres_patch_sampling
1153
+ else None
1154
+ )
1155
+ if hasattr(self, "image_newline"):
1156
+ patch_embeds = torch.cat(
1157
+ (patch_embeds, self.image_newline[None]), dim=0
1158
+ )
1159
+ if not self.anyres_patch_sampling:
1160
+ max_n_img_token = max(patch_embeds.shape[0], max_n_img_token)
1161
+
1162
+ new_image_embeds.append(patch_embeds)
1163
+ patch_attn_masks.append(patch_attn_mask)
1164
+
1165
+ if self.anyres_patch_sampling:
1166
+ # Return individual patches for independent token downsampling.
1167
+ return new_image_embeds, patch_attn_masks
1168
+
1169
+ # 4. Pad and concat the list of image_embeds [N_tok_i, C] together into a batch. Also modify the query attention mask.
1170
+ image_embeds = []
1171
+ image_atts = []
1172
+ for image_embed in new_image_embeds:
1173
+ n_img_token = image_embed.shape[0]
1174
+ img_attn = torch.ones(
1175
+ (max_n_img_token), dtype=torch.long, device=image_embed.device
1176
+ )
1177
+ if n_img_token < max_n_img_token:
1178
+ padded_embed = torch.zeros(
1179
+ (max_n_img_token, image_embed.shape[-1]),
1180
+ dtype=image_embed.dtype,
1181
+ device=image_embed.device,
1182
+ )
1183
+ padded_embed[:n_img_token, :] = image_embed
1184
+ img_attn[n_img_token:] = 0 # Mask out the padded entries.
1185
+ else:
1186
+ padded_embed = image_embed
1187
+ image_embeds.append(padded_embed)
1188
+ image_atts.append(img_attn)
1189
+ image_embeds = torch.stack(
1190
+ image_embeds, dim=0
1191
+ ) # Shape [B, N_tok_longest, C_dim]
1192
+ image_atts = torch.stack(image_atts, dim=0) # Shape [B, N_tok_longest]
1193
+ # TODO: reshape image_embeds and image_atts to "b T F v d"
1194
+ image_embeds = image_embeds[:, None, None, :, :]
1195
+ # image_atts = image_atts[:, None, None, :, :]
1196
+
1197
+ return image_embeds, image_atts
1198
+
1199
+ def _encode_vision_x(self, vision_x: torch.Tensor):
1200
+ """
1201
+ Compute vision features from the vision input by passing it through the vision encoder; the language model is conditioned on these downstream.
1202
+ Args:
1203
+ vision_x: Vision input
1204
+ shape (B, T_img, F, C, H, W)
1205
+ Images in the same chunk are collated along T_img, and frames are collated along F
1206
+ Currently only F=1 is supported (single-frame videos)
1207
+
1208
+ rearrange code based on https://github.com/dhansmair/flamingo-mini
1209
+ """
1210
+ assert vision_x.ndim == 6, "vision_x should be of shape (b, T_img, F, C, H, W)"
1211
+ b, T, F = vision_x.shape[:3]
1212
+
1213
+ vision_x = rearrange(vision_x, "b T F c h w -> (b T F) c h w")
1214
+ with torch.no_grad():
1215
+ if self.vision_encoder.__class__.__name__ == "TimmModel":
1216
+ vision_x = self.vision_encoder.trunk.forward_features(vision_x)
1217
+ elif self.vision_encoder.__class__.__name__ in [
1218
+ "CLIPVisionModel",
1219
+ "SiglipVisionTransformer",
1220
+ ]:
1221
+ vision_x = self.vision_encoder(vision_x).last_hidden_state
1222
+ else:
1223
+ vision_x = self.vision_encoder(vision_x)[1] # OpenCLIP returns tuples
1224
+ vision_x = rearrange(vision_x, "(b T F) v d -> b T F v d", b=b, T=T, F=F)
1225
+ return vision_x
1226
+
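For readers following the shape bookkeeping in `_encode_vision_x` above, a standalone sketch with assumed toy sizes (384x384 inputs and 729 patches of dimension 1152 stand in for the real encoder output):

```python
import torch
from einops import rearrange

b, T, F_, C, H, W = 2, 1, 1, 3, 384, 384                              # only F = 1 is supported
vision_x = torch.randn(b, T, F_, C, H, W)
flat = rearrange(vision_x, "b T F c h w -> (b T F) c h w")            # (2, 3, 384, 384) goes to the encoder
feats = torch.randn(b * T * F_, 729, 1152)                            # stand-in for encoder patch features
feats = rearrange(feats, "(b T F) v d -> b T F v d", b=b, T=T, F=F_)  # (2, 1, 1, 729, 1152)
```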
1227
+ def _concat_vision_cache(
1228
+ self, lang_x, vision_tokens, past_vision_tokens, past_media_locations, use_cache
1229
+ ):
1230
+ """
1231
+ Helper function to include the past vision tokens and past media locations in the output.
1232
+ """
1233
+ if use_cache:
1234
+ if past_media_locations is not None and past_vision_tokens is not None:
1235
+ if vision_tokens is not None:
1236
+ updated_vision_tokens = torch.cat(
1237
+ [
1238
+ past_vision_tokens,
1239
+ vision_tokens,
1240
+ ],
1241
+ dim=1,
1242
+ )
1243
+ else:
1244
+ updated_vision_tokens = past_vision_tokens
1245
+ updated_media_locations = torch.cat(
1246
+ [
1247
+ past_media_locations,
1248
+ lang_x == self.media_token_id,
1249
+ ],
1250
+ dim=1,
1251
+ )
1252
+ else:
1253
+ updated_vision_tokens = vision_tokens
1254
+ updated_media_locations = lang_x == self.media_token_id
1255
+
1256
+ else:
1257
+ updated_vision_tokens = None
1258
+ updated_media_locations = None
1259
+
1260
+ return updated_vision_tokens, updated_media_locations
1261
+
1262
+ def generate(
1263
+ self,
1264
+ vision_x: torch.Tensor,
1265
+ lang_x: torch.Tensor,
1266
+ attention_mask: torch.Tensor = None,
1267
+ past_key_values: Optional[
1268
+ List[Union[torch.Tensor, Tuple[torch.Tensor]]]
1269
+ ] = None,
1270
+ past_media_locations: Optional[torch.Tensor] = None,
1271
+ past_vision_tokens: Optional[torch.Tensor] = None,
1272
+ **kwargs,
1273
+ ):
1274
+ """
1275
+ Generate text conditioned on vision and language inputs.
1276
+ Args:
1277
+ vision_x (torch.Tensor): Vision input
1278
+ shape (B, T_img, F, C, H, W)
1279
+ see documentation for forward
1280
+ lang_x (torch.Tensor): Language input
1281
+ shape (B, T_txt)
1282
+ attention_mask (torch.Tensor, optional): Attention mask. Defaults to None.
1283
+ **kwargs: see generate documentation in Hugging Face CausalLM models.
1284
+ Returns:
1285
+ torch.Tensor: lang_x with generated tokens appended to it
1286
+ """
1287
+ num_beams = kwargs.pop("num_beams", 1)
1288
+
1289
+ # convert pixels to vision tokens
1290
+ if vision_x is not None:
1291
+ vision_features = self._encode_vision_x(vision_x=vision_x)
1292
+ vision_tokens = self.vision_tokenizer(vision_features)
1293
+ else:
1294
+ vision_tokens = None
1295
+
1296
+ # fuse the vision and language tokens
1297
+ # for xattn, vision_x and media_location are repeat_interleaved s.t.
1298
+ # the total batch size is B * num_beams
1299
+ new_inputs = self._prepare_inputs_for_forward(
1300
+ vision_tokens=vision_tokens,
1301
+ lang_x=lang_x,
1302
+ attention_mask=attention_mask,
1303
+ past_key_values=past_key_values,
1304
+ past_media_locations=past_media_locations,
1305
+ past_vision_tokens=past_vision_tokens,
1306
+ padding_side="left",
1307
+ num_beams=num_beams,
1308
+ )
1309
+ output = self.lang_model.generate(
1310
+ **new_inputs,
1311
+ past_key_values=past_key_values,
1312
+ num_beams=num_beams,
1313
+ use_cache=True,
1314
+ **kwargs,
1315
+ )
1316
+ self._post_forward_hook()
1317
+ return output
1318
+
1319
+ @property
1320
+ def num_trainable_params(self):
1321
+ """Print the number of trainable parameters"""
1322
+ return num_params(self, filter_to_trainable=True)
1323
+
1324
+ def set_trainable(self):
1325
+ """
1326
+ Freeze appropriate parameters in the model.
1327
+ """
1328
+ raise NotImplementedError
1329
+
1330
+ def group_params_by_weight_decay(self):
1331
+ """
1332
+ Return a tuple of (params to optimize w/ weight decay, params to optimize w/o weight decay)
1333
+ """
1334
+ params_with_wd, params_without_wd = [], []
1335
+ for n, p in self.named_parameters():
1336
+ if p.requires_grad:
1337
+ if self._should_apply_weight_decay(n):
1338
+ params_with_wd.append(p)
1339
+ else:
1340
+ params_without_wd.append(p)
1341
+ return params_with_wd, params_without_wd
1342
+
1343
+ def _should_apply_weight_decay(self, parameter_name):
1344
+ """
1345
+ Return whether weight decay should be applied to a parameter.
1346
+ """
1347
+ raise NotImplementedError
1348
+
1349
+ @property
1350
+ def special_tokens(self):
1351
+ """
1352
+ Returns a dict mapping from the attribute name of a special token to its string format,
1353
+ e.g. "media_token": "<image>"
1354
+ """
1355
+ assert (
1356
+ "media_token" in self._special_tokens
1357
+ ), "VLMs need to request that the tokenizer add a media_token and call set_special_token_ids to set self.media_token_id"
1358
+ return self._special_tokens
1359
+
1360
+ @property
1361
+ def special_token_ids(self):
1362
+ """
1363
+ Returns a list of the special token ids
1364
+ """
1365
+ return [getattr(self, f"{att_name}_id") for att_name in self.special_tokens]
1366
+
1367
+ def set_special_token_ids(self, string_to_ids):
1368
+ """
1369
+ Args:
1370
+ string_to_ids (dict): mapping from token string to id
1371
+ """
1372
+ assert set(self.special_tokens.values()).issubset(set(string_to_ids.keys()))
1373
+ for att_name, token_str in self.special_tokens.items():
1374
+ token_id = string_to_ids[token_str]
1375
+ setattr(self, f"{att_name}_id", token_id)
1376
+ setattr(self.lang_model, f"{att_name}_id", token_id)
1377
+
1378
+ def init_gradient_checkpointing(self):
1379
+ from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
1380
+ checkpoint_wrapper,
1381
+ CheckpointWrapper,
1382
+ CheckpointImpl,
1383
+ apply_activation_checkpointing,
1384
+ )
1385
+ from functools import partial
1386
+
1387
+ non_reentrant_wrapper = partial(
1388
+ checkpoint_wrapper,
1389
+ checkpoint_impl=CheckpointImpl.NO_REENTRANT,
1390
+ )
1391
+ apply_activation_checkpointing(
1392
+ self,
1393
+ checkpoint_wrapper_fn=non_reentrant_wrapper,
1394
+ check_fn=lambda m: getattr(m, "_use_gradient_checkpointing", False)
1395
+ and not isinstance(m, CheckpointWrapper),
1396
+ )
1397
+
1398
+
1399
+ @dataclass
1400
+ class VLMOutputWithPast(CausalLMOutputWithPast):
1401
+ """
1402
+ VLMOutputWithPast is a wrapper around CausalLMOutputWithPast that adds the following attributes:
1403
+ past_media_locations: Optional[torch.Tensor] = None,
1404
+ past_vision_tokens: Optional[torch.Tensor] = None,
1405
+ """
1406
+
1407
+ past_media_locations: Optional[torch.Tensor] = None
1408
+ past_vision_tokens: Optional[torch.Tensor] = None
1409
+
1410
+
1411
+ def exists(val):
1412
+ return val is not None
1413
+
1414
+
1415
+ def FeedForward(dim, mult=4):
1416
+ inner_dim = int(dim * mult)
1417
+ return nn.Sequential(
1418
+ nn.LayerNorm(dim),
1419
+ nn.Linear(dim, inner_dim, bias=False),
1420
+ nn.GELU(),
1421
+ nn.Linear(inner_dim, dim, bias=False),
1422
+ )
1423
+
1424
+
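For reference, a quick check of the `FeedForward` helper above (the import path is an assumption; 3072 matches the hidden size used elsewhere in this file):

```python
import torch
from modeling_xgenmm import FeedForward  # assumed local import of the merged file

ff = FeedForward(dim=3072)               # LayerNorm -> Linear(3072, 12288) -> GELU -> Linear(12288, 3072)
y = ff(torch.randn(2, 16, 3072))         # shape preserved: (2, 16, 3072)
```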
1425
+ class VLMWithLanguageStream(VLM):
1426
+ """
1427
+ VLM that fuses modalities by inserting vision tokens directly into the language stream.
1428
+ """
1429
+
1430
+ def __init__(
1431
+ self,
1432
+ vision_encoder: nn.Module,
1433
+ vision_tokenizer: nn.Module,
1434
+ lang_model: nn.Module,
1435
+ initial_tokenizer_len: int,
1436
+ pad_token_id: int,
1437
+ decoder_layers_attr_name: str = None,
1438
+ gradient_checkpointing: bool = False,
1439
+ ):
1440
+ super().__init__(
1441
+ vision_encoder=vision_encoder,
1442
+ vision_tokenizer=vision_tokenizer,
1443
+ lang_model=lang_model,
1444
+ initial_tokenizer_len=initial_tokenizer_len,
1445
+ pad_token_id=pad_token_id,
1446
+ gradient_checkpointing=gradient_checkpointing,
1447
+ )
1448
+ self.decoder_layers_attr_name = decoder_layers_attr_name
1449
+ if decoder_layers_attr_name is not None:
1450
+ for block in getattr_recursive(
1451
+ self.lang_model, self.decoder_layers_attr_name
1452
+ ):
1453
+ block._use_gradient_checkpointing = gradient_checkpointing
1454
+
1455
+ def _prepare_inputs_for_forward(
1456
+ self,
1457
+ vision_tokens: torch.Tensor,
1458
+ lang_x: torch.Tensor,
1459
+ attention_mask: torch.Tensor,
1460
+ labels: torch.Tensor = None,
1461
+ past_key_values=None,
1462
+ vision_attention_mask: Optional[torch.Tensor] = None,
1463
+ past_media_locations: torch.Tensor = None,
1464
+ past_vision_tokens: torch.Tensor = None,
1465
+ padding_side: str = "left",
1466
+ num_beams: int = 1,
1467
+ ):
1468
+ """
1469
+ Insert the vision tokens directly into the language stream.
1470
+ This requires us to modify the input_ids, attention_mask, and labels.
1471
+ """
1472
+ if past_key_values is not None:
1473
+ past_len = past_key_values[0][0].shape[2]
1474
+ assert attention_mask.shape[1] == past_len + lang_x.shape[1], (
1475
+ "Attention_mask must be as long as the entire past len (including image tokens) and current input IDs. "
1476
+ + "Check that you've expanded the attention mask to account for past image tokens."
1477
+ )
1478
+
1479
+ if vision_tokens is None:
1480
+ return {
1481
+ "input_ids": lang_x,
1482
+ "attention_mask": attention_mask,
1483
+ "labels": labels,
1484
+ }
1485
+
1486
+ # get the language embeddings
1487
+ lang_embeds = self.lang_model.get_input_embeddings()(lang_x)
1488
+
1489
+ # build up the multimodal embeddings
1490
+ B = lang_x.shape[0]
1491
+ has_labels = labels is not None
1492
+ multimodal_embeds = []
1493
+ multimodal_attention_mask = []
1494
+ multimodal_labels = [] if has_labels else None
1495
+ for i in range(B):
1496
+ # get index of <image> tokens in lang_x[i]
1497
+ image_token_idxs = torch.where(lang_x[i] == self.media_token_id)[0]
1498
+
1499
+ if len(image_token_idxs) == 0:
1500
+ multimodal_embeds.append(lang_embeds[i].clone())
1501
+ multimodal_attention_mask.append(attention_mask[i].clone())
1502
+ if has_labels:
1503
+ multimodal_labels.append(labels[i].clone())
1504
+ continue
1505
+
1506
+ # loop through the image_token_idxs and insert the vision tokens
1507
+ new_embed = lang_embeds[i].clone()
1508
+ new_attention_mask = (
1509
+ attention_mask[i].clone() if attention_mask is not None else None
1510
+ )
1511
+ if has_labels:
1512
+ new_label = labels[i].clone()
1513
+
1514
+ for img_num, img_idx in enumerate(image_token_idxs):
1515
+ new_embed = torch.cat(
1516
+ (
1517
+ new_embed[:img_idx],
1518
+ vision_tokens[i][img_num],
1519
+ new_embed[img_idx + self.num_tokens_per_vis :],
1520
+ ),
1521
+ dim=0,
1522
+ )
1523
+ new_attention_mask = torch.cat(
1524
+ (
1525
+ new_attention_mask[:img_idx],
1526
+ torch.ones(self.num_tokens_per_vis, dtype=torch.long).to(
1527
+ attention_mask.device
1528
+ ),
1529
+ new_attention_mask[img_idx + self.num_tokens_per_vis :],
1530
+ ),
1531
+ dim=0,
1532
+ )
1533
+ if has_labels:
1534
+ new_label = torch.cat(
1535
+ (
1536
+ new_label[:img_idx],
1537
+ torch.ones(self.num_tokens_per_vis, dtype=torch.long).to(
1538
+ labels.device
1539
+ )
1540
+ * -100,
1541
+ new_label[img_idx + self.num_tokens_per_vis :],
1542
+ ),
1543
+ dim=0,
1544
+ )
1545
+ multimodal_embeds.append(new_embed)
1546
+ multimodal_attention_mask.append(new_attention_mask)
1547
+ if has_labels:
1548
+ multimodal_labels.append(new_label)
1549
+
1550
+ # stack
1551
+ multimodal_embeds = stack_with_padding(
1552
+ multimodal_embeds,
1553
+ padding_value=self.pad_token_id,
1554
+ padding_side=padding_side,
1555
+ )
1556
+ multimodal_attention_mask = stack_with_padding(
1557
+ multimodal_attention_mask,
1558
+ padding_value=0,
1559
+ padding_side=padding_side,
1560
+ )
1561
+ if has_labels:
1562
+ multimodal_labels = stack_with_padding(
1563
+ multimodal_labels,
1564
+ padding_value=-100,
1565
+ padding_side=padding_side,
1566
+ )
1567
+
1568
+ return {
1569
+ "inputs_embeds": multimodal_embeds,
1570
+ "attention_mask": multimodal_attention_mask,
1571
+ "labels": multimodal_labels,
1572
+ }
1573
+
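The splice performed in `_prepare_inputs_for_forward` above replaces `num_tokens_per_vis` placeholder positions starting at each `<image>` index with the corresponding vision tokens, so the sequence length is unchanged. A standalone toy version of that concatenation (all sizes are assumptions):

```python
import torch

num_tokens_per_vis, embed_dim = 4, 8
lang_embeds = torch.randn(10, embed_dim)              # T_txt = 10 positions, placeholders included
vision_tok = torch.randn(num_tokens_per_vis, embed_dim)
img_idx = 3                                            # index of the <image> token in lang_x
spliced = torch.cat(
    (lang_embeds[:img_idx], vision_tok, lang_embeds[img_idx + num_tokens_per_vis:]),
    dim=0,
)
assert spliced.shape == lang_embeds.shape              # length preserved: 4 rows are swapped in place
```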
1574
+ def _postprocess_outputs_from_forward(
1575
+ self,
1576
+ output: CausalLMOutputWithPast,
1577
+ lang_x: torch.Tensor,
1578
+ vision_tokens: torch.Tensor,
1579
+ past_vision_tokens: torch.Tensor,
1580
+ past_media_locations: torch.Tensor,
1581
+ use_cache: bool = False,
1582
+ ):
1583
+ # Include the past vision tokens and past media locations in the output
1584
+ updated_vision_tokens, updated_media_locations = self._concat_vision_cache(
1585
+ lang_x=lang_x,
1586
+ vision_tokens=vision_tokens,
1587
+ past_vision_tokens=past_vision_tokens,
1588
+ past_media_locations=past_media_locations,
1589
+ use_cache=use_cache,
1590
+ )
1591
+
1592
+ # return logits that are the same shape as the original input_ids
1593
+ logits = output.logits
1594
+ batch_logits = []
1595
+ B, T_txt = lang_x.shape
1596
+ for i in range(B):
1597
+ sequence_logits = []
1598
+ logits_j = 0
1599
+ for j in range(T_txt):
1600
+ if lang_x[i, j] != self.media_token_id:
1601
+ sequence_logits.append(logits[i, logits_j])
1602
+ logits_j += 1
1603
+ else:
1604
+ # append the logit for the first image token, then skip over the rest
1605
+ # note: the model actually learns to predict <im_patch>, not <image>
1606
+ sequence_logits.append(logits[i, logits_j])
1607
+ logits_j += self.num_tokens_per_vis
1608
+ sequence_logits = torch.stack(sequence_logits, dim=0) # (B, vocab_size)
1609
+ batch_logits.append(sequence_logits)
1610
+
1611
+ batch_logits = torch.stack(batch_logits, dim=0) # (B, T_txt, vocab_size)
1612
+ # The final logits shape should be the same as the original input_ids shape
1613
+ assert batch_logits.shape[:2] == (B, T_txt)
1614
+
1615
+ # assemble the output
1616
+ output = VLMOutputWithPast(
1617
+ loss=output.loss,
1618
+ logits=batch_logits,
1619
+ past_key_values=output.past_key_values,
1620
+ hidden_states=output.hidden_states,
1621
+ attentions=output.attentions,
1622
+ past_media_locations=updated_media_locations,
1623
+ past_vision_tokens=updated_vision_tokens,
1624
+ )
1625
+
1626
+ return output
1627
+
1628
+ def _post_forward_hook(self):
1629
+ pass
1630
+
1631
+ @property
1632
+ def num_params_per_module(self):
1633
+ """Print the number of parameters per module in the model"""
1634
+ return "\n".join(
1635
+ [
1636
+ f"Vision encoder: {num_params(self.vision_encoder):,} parameters",
1637
+ f"Vision tokenizer: {num_params(self.vision_tokenizer):,} parameters",
1638
+ f"Language model: {num_params(self.lang_model):,} parameters",
1639
+ ]
1640
+ )
1641
+
1642
+ @property
1643
+ def num_trainable_params_per_module(self):
1644
+ """Print the number of trainable parameters per module in the model"""
1645
+ return "\n".join(
1646
+ [
1647
+ f"Vision encoder: {num_params(self.vision_encoder, filter_to_trainable=True):,} trainable parameters",
1648
+ f"Vision tokenizer: {num_params(self.vision_tokenizer, filter_to_trainable=True):,} trainable parameters",
1649
+ f"Language model: {num_params(self.lang_model, filter_to_trainable=True):,} trainable parameters",
1650
+ ]
1651
+ )
1652
+
1653
+
1654
+ class XGenMMPerceiver(VLMWithLanguageStream):
1655
+ def __init__(
1656
+ self,
1657
+ vision_encoder: nn.Module,
1658
+ vision_tokenizer: nn.Module,
1659
+ lang_model: nn.Module,
1660
+ initial_tokenizer_len: int,
1661
+ pad_token_id: int,
1662
+ decoder_layers_attr_name: str = None,
1663
+ gradient_checkpointing: bool = False,
1664
+ image_aspect_ratio: str = "none",
1665
+ ):
1666
+ """
1667
+ Args:
1668
+ vision_encoder (nn.Module): HF vision tower, e.g. the SigLIP vision model
1669
+ vision_tokenizer (nn.Module): perceiver resampler mapping vision features to language-model-sized tokens
1670
+ lang_model (nn.Module): HF causal language model
1671
+ initial_tokenizer_len (int): size of the tokenizer vocab
1672
+ pad_token_id (int): id of the padding token. None if no padding token; then a padding token
1673
+ will be inserted into self.special_tokens, which factory.py fills after creating new tokens
1674
+ decoder_layers_attr_name (str, optional): name of the decoder layers attribute. Defaults to None.
1675
+ gradient_checkpointing (bool, optional): whether to use gradient checkpointing. Defaults to False.
1676
+ """
1677
+ self._special_tokens = {
1678
+ "media_token": "<image>",
1679
+ "image_placeholder_token": "<image placeholder>",
1680
+ "end_of_trunk_token": "<|endofchunk|>",
1681
+ }
1682
+ lang_embedding_dim = lang_model.get_input_embeddings().weight.shape[1]
1683
+ super().__init__(
1684
+ vision_encoder=vision_encoder,
1685
+ vision_tokenizer=vision_tokenizer,
1686
+ lang_model=lang_model,
1687
+ initial_tokenizer_len=initial_tokenizer_len,
1688
+ gradient_checkpointing=gradient_checkpointing,
1689
+ decoder_layers_attr_name=decoder_layers_attr_name,
1690
+ pad_token_id=pad_token_id,
1691
+ )
1692
+ self.image_aspect_ratio = image_aspect_ratio
1693
+
1694
+ def set_trainable(self):
1695
+ """
1696
+ Unfreeze everything except the vision_encoder
1697
+ """
1698
+ self.requires_grad_(True)
1699
+ self.vision_encoder.requires_grad_(False)
1700
+
1701
+ def _should_apply_weight_decay(self, parameter_name):
1702
+ """
1703
+ Kosmos applies 0.01 weight decay to everything
1704
+ """
1705
+ return True
1706
+
1707
+ def generate(
1708
+ self,
1709
+ vision_x: torch.Tensor,
1710
+ lang_x: torch.Tensor,
1711
+ image_size: Optional[Tuple] = None,
1712
+ attention_mask: torch.Tensor = None,
1713
+ past_key_values: Optional[
1714
+ List[Union[torch.Tensor, Tuple[torch.Tensor]]]
1715
+ ] = None,
1716
+ past_media_locations: Optional[torch.Tensor] = None,
1717
+ past_vision_tokens: Optional[torch.Tensor] = None,
1718
+ **kwargs,
1719
+ ):
1720
+ """
1721
+ Generate text conditioned on vision and language inputs.
1722
+ Args:
1723
+ vision_x (torch.Tensor): Vision input
1724
+ shape (B, T_img, F, C, H, W)
1725
+ see documentation for forward
1726
+ lang_x (torch.Tensor): Language input
1727
+ shape (B, T_txt)
1728
+ attention_mask (torch.Tensor, optional): Attention mask. Defaults to None.
1729
+ **kwargs: see generate documentation in Hugging Face CausalLM models.
1730
+ Returns:
1731
+ torch.Tensor: lang_x with generated tokens appended to it
1732
+ """
1733
+ num_beams = kwargs.pop("num_beams", 1)
1734
+
1735
+ # convert pixels to vision tokens
1736
+ vision_attention_mask = None
1737
+ if vision_x is not None:
1738
+ vision_features = self._encode_vision_x(vision_x=vision_x)
1739
+ vision_tokens = self.vision_tokenizer(vision_features)
1740
+ else:
1741
+ vision_tokens = None
1742
+
1743
+ # fuse the vision and language tokens
1744
+ # for xattn, vision_x and media_location are repeat_interleaved s.t.
1745
+ # the total batch size is B * num_beams
1746
+ new_inputs = self._prepare_inputs_for_forward(
1747
+ vision_tokens=vision_tokens,
1748
+ lang_x=lang_x,
1749
+ attention_mask=attention_mask,
1750
+ vision_attention_mask=vision_attention_mask,
1751
+ past_key_values=past_key_values,
1752
+ past_media_locations=past_media_locations,
1753
+ past_vision_tokens=past_vision_tokens,
1754
+ padding_side="left",
1755
+ num_beams=num_beams,
1756
+ )
1757
+ if past_key_values is not None:
1758
+ output = self.lang_model.generate(
1759
+ **new_inputs,
1760
+ past_key_values=past_key_values,
1761
+ num_beams=num_beams,
1762
+ use_cache=True,
1763
+ **kwargs,
1764
+ )
1765
+ else:
1766
+ output = self.lang_model.generate(
1767
+ **new_inputs,
1768
+ num_beams=num_beams,
1769
+ use_cache=True,
1770
+ **kwargs,
1771
+ )
1772
+ self._post_forward_hook()
1773
+ return output
1774
+
1775
 
1776
  class XGenMMVisionEncoder(PreTrainedModel):
1777
  main_input_name = "pixel_values"
1778
  config_class = XGenMMVisionEncoderConfig
1779
+
1780
  def __init__(self, config: XGenMMVisionEncoderConfig):
1781
  super().__init__(config)
1782
+ if config.model_name != "google/siglip-so400m-patch14-384":
1783
+ raise ValueError(
1784
+ f"Unsupported model {config.model_name}. New vision models will be added soon."
1785
+ )
1786
  self.model = AutoModel.from_pretrained(config.model_name)
1787
+
1788
  def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
1789
  # assert pixel_values.ndim == 4, f"Expected 4D tensor (bs, c, h, w), got {pixel_values.ndim}"
1790
  return self.model.encode_image(pixel_values)
 
1791
 
1792
+
1793
+ # vision tokenizer
1794
  class XGenMMVisionTokenizer(PreTrainedModel):
1795
  config_class = XGenMMVisionTokenizerConfig
1796
+
1797
  def __init__(self, config: XGenMMVisionTokenizerConfig):
1798
  super().__init__(config)
1799
  self.model = PerceiverResampler(
 
1801
  dim_inner=config.lang_embedding_dim,
1802
  num_latents=config.num_vis_tokens,
1803
  )
1804
+
1805
+ def forward(self, vision_features: torch.Tensor, vision_attn_masks: torch.Tensor):
 
 
1806
  return self.model(vision_features, vision_attn_masks)
1807
+
1808
+
1809
  # XGenMM model
1810
  class XGenMMModelForConditionalGeneration(PreTrainedModel):
1811
  config_class = XGenMMConfig
1812
+
1813
  def __init__(self, config: XGenMMConfig):
1814
  super().__init__(config)
1815
+
1816
  # vision encoder initialization
1817
+ vision_encoder = AutoModel.from_pretrained(
1818
+ config.vision_encoder_config.model_name
1819
+ ).vision_model
1820
+
1821
+ # language model initialization
1822
  language_model = AutoModelForCausalLM.from_config(config.text_config)
1823
  check_embedding_fns(language_model)
1824
  # Update _tied_weights_keys using the base model used.
1825
  if language_model._tied_weights_keys is not None:
1826
+ self._tied_weights_keys = [
1827
+ f"language_model.{k}" for k in language_model._tied_weights_keys
1828
+ ]
1829
+
1830
  # vision tokenizer initialization
1831
+ if (
1832
+ config.vision_tokenizer_config.lang_embedding_dim
1833
+ != language_model.get_input_embeddings().weight.shape[1]
1834
+ ):
1835
  overwrite = language_model.get_input_embeddings().weight.shape[1]
1836
  config.vision_tokenizer_config.lang_embedding_dim = overwrite
1837
+ print(
1838
+ f"Warning: The language embedding dimension in the vision tokenizer config is different from the language model's embedding dimension. Overwriting the language embedding dimension in the vision tokenizer config to {overwrite}."
1839
+ )
1840
+
1841
  vision_tokenizer = XGenMMVisionTokenizer(config.vision_tokenizer_config).model
1842
 
1843
  self.vlm = XGenMMPerceiver(
1844
  vision_encoder=vision_encoder,
1845
  vision_tokenizer=vision_tokenizer,
1846
  lang_model=language_model,
1847
+ initial_tokenizer_len=config.text_config.initial_tokenizer_len,
1848
+ pad_token_id=config.text_config.pad_token_id,
1849
+ image_aspect_ratio=config.vision_encoder_config.image_aspect_ratio,
1850
  )
1851
  # Initialize weights and apply final processing
1852
  self.post_init()
1853
+
1854
  @torch.no_grad()
1855
  def generate(
1856
  self,
 
1858
  input_ids: Optional[torch.LongTensor] = None,
1859
  attention_mask: Optional[torch.LongTensor] = None,
1860
  **generate_kwargs,
1861
+ ) -> torch.LongTensor:
1862
  self.vlm = self.vlm.eval()
1863
  return self.vlm.generate(
1864
+ vision_x=pixel_values,
1865
+ lang_x=input_ids,
1866
+ attention_mask=attention_mask,
1867
+ **generate_kwargs,
1868
+ )
1869
+
1870
  def update_special_tokens(self, tokenizer):
1871
  tokenizer.add_special_tokens(
1872
  {"additional_special_tokens": list(self.vlm.special_tokens.values())}
1873
  )
1874
  self.vlm.lang_model.config.vocab_size = len(tokenizer)
1875
  self.vlm.set_special_token_ids(
1876
  {
1877
+ v: tokenizer.convert_tokens_to_ids(v)
1878
+ for v in self.vlm.special_tokens.values()
1879
  }
1880
  )
1881
  return tokenizer
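Finally, a hedged sketch of how the merged file is meant to be consumed via `trust_remote_code`; the checkpoint id below is an assumption for illustration, not something stated in this diff:

```python
from transformers import AutoModelForVision2Seq, AutoTokenizer

ckpt = "Salesforce/xgen-mm-phi3-mini-instruct-r-v1"    # assumed checkpoint id
model = AutoModelForVision2Seq.from_pretrained(ckpt, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(ckpt, trust_remote_code=True)
tokenizer = model.update_special_tokens(tokenizer)     # registers <image> etc. and records their ids
```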
 
utils.py DELETED
@@ -1,383 +0,0 @@
1
- import torch
2
- import ast
3
- import math
4
- from PIL import Image
5
- from packaging.version import Version
6
-
7
- def has_fn(model, fn_name):
8
- """Check if model has a function fn_name"""
9
- return callable(getattr(model, fn_name, None))
10
-
11
- def exists(val):
12
- return val is not None
13
-
14
- def num_params(module, filter_to_trainable=False):
15
- """Returns the number of parameters in the module, or optionally only the trainable parameters"""
16
- if filter_to_trainable:
17
- return sum(p.numel() for p in module.parameters() if p.requires_grad)
18
- else:
19
- return sum(p.numel() for p in module.parameters())
20
-
21
- def hasattr_recursive(obj, att):
22
- """
23
- Check if obj has nested attribute
24
- Example: hasattr_recursive(obj, 'a.b.c') is equivalent to hasattr(obj, 'a') and hasattr(obj.a, 'b') and hasattr(obj.a.b, 'c')
25
- """
26
- if att == "":
27
- return True
28
- i = att.find(".")
29
- if i < 0:
30
- return hasattr(obj, att)
31
- else:
32
- try:
33
- return hasattr_recursive(getattr(obj, att[:i]), att[i + 1 :])
34
- except:
35
- return False
36
-
37
- def getattr_recursive(obj, att):
38
- """
39
- Return nested attribute of obj
40
- Example: getattr_recursive(obj, 'a.b.c') is equivalent to obj.a.b.c
41
- """
42
- if att == "":
43
- return obj
44
- i = att.find(".")
45
- if i < 0:
46
- return getattr(obj, att)
47
- else:
48
- return getattr_recursive(getattr(obj, att[:i]), att[i + 1 :])
49
-
50
-
51
- def setattr_recursive(obj, att, val):
52
- """
53
- Set nested attribute of obj
54
- Example: setattr_recursive(obj, 'a.b.c', val) is equivalent to obj.a.b.c = val
55
- """
56
- if "." in att:
57
- obj = getattr_recursive(obj, ".".join(att.split(".")[:-1]))
58
- setattr(obj, att.split(".")[-1], val)
59
-
60
-
61
- def stack_with_padding(list_of_tensors, padding_value=0, padding_side="right"):
62
- """
63
- Stack a list of tensors with padding on one side
64
- Args:
65
- list_of_tensors (list[torch.Tensor]): List of tensors to stack
66
- padding_value (int, optional): Value to pad with. Defaults to 0.
67
- padding_side (str, optional): Side to pad on. Defaults to "right".
68
- Returns:
69
- torch.Tensor: Stacked tensors
70
- """
71
- max_tokens = max(tensor.size(0) for tensor in list_of_tensors)
72
- padded_tensors = []
73
- for tensor in list_of_tensors:
74
- num_tokens = tensor.size(0)
75
- if len(tensor.size()) == 1:
76
- padding = torch.full(
77
- (max_tokens - num_tokens,),
78
- padding_value,
79
- dtype=tensor.dtype,
80
- device=tensor.device,
81
- )
82
- else:
83
- padding = torch.full(
84
- (max_tokens - num_tokens, tensor.size(1)),
85
- padding_value,
86
- dtype=tensor.dtype,
87
- device=tensor.device,
88
- )
89
- padded_tensor = (
90
- torch.cat((tensor, padding), dim=0)
91
- if padding_side == "right"
92
- else torch.cat((padding, tensor), dim=0)
93
- )
94
- padded_tensors.append(padded_tensor)
95
- return torch.stack(padded_tensors)
96
-
97
-
98
- def check_embedding_fns(lang_model):
99
- """Checks for and attempts to set {get/set}_{input/output}_embeddings functions to the model"""
100
- if not has_fn(lang_model, "get_input_embeddings"):
101
- if hasattr_recursive(lang_model, "transformer.wte"): # MPT
102
- lang_model.get_input_embeddings = lambda: lang_model.transformer.wte
103
- elif hasattr_recursive(lang_model, "model.decoder.embed_tokens"): # OPT
104
- lang_model.get_input_embeddings = lambda: lang_model.decoder.embed_tokens
105
- else:
106
- raise ValueError(
107
- "We require the language encoder to have a get_input_embeddings method but we couldn't determine the name of the input embeddings attribute. Please supply this manually in factory.py."
108
- )
109
-
110
- if not has_fn(lang_model, "set_input_embeddings"):
111
- if hasattr_recursive(lang_model, "transformer.wte"): # MPT
112
- lang_model.set_input_embeddings = lambda x: setattr_recursive(
113
- lang_model, "transformer.wte", x
114
- )
115
- elif hasattr_recursive(lang_model, "model.decoder.embed_tokens"): # OPT
116
- lang_model.set_input_embeddings = lambda x: setattr_recursive(
117
- lang_model, "model.decoder.embed_tokens", x
118
- )
119
- else:
120
- raise ValueError(
121
- "We require the language encoder to have a set_input_embeddings method but we couldn't determine the name of the input embeddings attribute. Please supply this manually in factory.py."
122
- )
123
-
124
- if not has_fn(lang_model, "get_output_embeddings"):
125
- if hasattr_recursive(lang_model, "lm_head"):
126
- lang_model.get_output_embeddings = lambda: lang_model.lm_head
127
- else:
128
- raise ValueError(
129
- "We require the language encoder to have a get_output_embeddings method but we couldn't determine the name of the output embeddings attribute. Please supply this manually in factory.py."
130
- )
131
-
132
- if not has_fn(lang_model, "set_output_embeddings"):
133
- if hasattr_recursive(lang_model, "lm_head"):
134
- lang_model.set_output_embeddings = lambda x: setattr_recursive(
135
- lang_model, "lm_head", x
136
- )
137
- else:
138
- raise ValueError(
139
- "We require the language encoder to have a set_output_embeddings method but we couldn't determine the name of the output embeddings attribute. Please supply this manually in factory.py."
140
- )
141
-
142
-
143
- def has_fn(model, fn_name):
144
- """Check if model has a function fn_name"""
145
- return callable(getattr(model, fn_name, None))
146
-
147
-
148
- # Adopted from https://github.com/haotian-liu/LLaVA. Below is the original copyright:
149
- #
150
- # Licensed under the Apache License, Version 2.0 (the "License");
151
- # you may not use this file except in compliance with the License.
152
- # You may obtain a copy of the License at
153
- #
154
- # http://www.apache.org/licenses/LICENSE-2.0
155
- #
156
- # Unless required by applicable law or agreed to in writing, software
157
- # distributed under the License is distributed on an "AS IS" BASIS,
158
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
159
- # See the License for the specific language governing permissions and
160
- # limitations under the License.
161
-
162
- def unpad_image(tensor, original_size, keep_original_shape=False):
163
- """
164
- Unpads a PyTorch tensor of a padded and resized image.
165
-
166
- Args:
167
- tensor (torch.Tensor): The image tensor, assumed to be in CxHxW format.
168
- original_size (tuple): The original size of the image (height, width).
169
-
170
- Returns:
171
- torch.Tensor: The unpadded image tensor.
172
- """
173
- original_width, original_height = original_size
174
- current_height, current_width = tensor.shape[1:]
175
-
176
- original_aspect_ratio = original_width / original_height
177
- current_aspect_ratio = current_width / current_height
178
-
179
- if original_aspect_ratio > current_aspect_ratio:
180
- scale_factor = current_width / original_width
181
- new_height = int(original_height * scale_factor)
182
- padding = (current_height - new_height) // 2
183
- if keep_original_shape:
184
- attention_mask = torch.ones((current_height, current_width), device=tensor.device)
185
- attention_mask[:padding, :] = 0
186
- attention_mask[current_height - padding:, :] = 0
187
- return tensor, attention_mask
188
- else:
189
- unpadded_tensor = tensor[:, padding:current_height - padding, :]
190
- return unpadded_tensor, None
191
- else:
192
- scale_factor = current_height / original_height
193
- new_width = int(original_width * scale_factor)
194
- padding = (current_width - new_width) // 2
195
- if keep_original_shape:
196
- attention_mask = torch.ones((current_height, current_width), device=tensor.device)
197
- attention_mask[:, :padding] = 0
198
- attention_mask[:, current_width - padding:] = 0
199
- return tensor, attention_mask
200
- else:
201
- unpadded_tensor = tensor[:, :, padding:current_width - padding]
202
- return unpadded_tensor, None
203
-
204
-
205
- def select_best_resolution(original_size, possible_resolutions):
206
- """
207
- Selects the best resolution from a list of possible resolutions based on the original size.
208
-
209
- Args:
210
- original_size (tuple): The original size of the image in the format (width, height).
211
- possible_resolutions (list): A list of possible resolutions in the format [(width1, height1), (width2, height2), ...].
212
-
213
- Returns:
214
- tuple: The best fit resolution in the format (width, height).
215
- """
216
- original_width, original_height = original_size
217
- best_fit = None
218
- max_effective_resolution = 0
219
- min_wasted_resolution = float('inf')
220
-
221
- for width, height in possible_resolutions:
222
- scale = min(width / original_width, height / original_height)
223
- downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale)
224
- effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height)
225
- wasted_resolution = (width * height) - effective_resolution
226
-
227
- if effective_resolution > max_effective_resolution or (effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution):
228
- max_effective_resolution = effective_resolution
229
- min_wasted_resolution = wasted_resolution
230
- best_fit = (width, height)
231
-
232
- return best_fit
233
-
234
-
235
- def resize_and_pad_image(image, target_resolution):
236
- """
237
- Resize and pad an image to a target resolution while maintaining aspect ratio.
238
-
239
- Args:
240
- image (PIL.Image.Image): The input image.
241
- target_resolution (tuple): The target resolution (width, height) of the image.
242
-
243
- Returns:
244
- PIL.Image.Image: The resized and padded image.
245
- """
246
- original_width, original_height = image.size
247
- target_width, target_height = target_resolution
248
-
249
- scale_w = target_width / original_width
250
- scale_h = target_height / original_height
251
-
252
- if scale_w < scale_h:
253
- new_width = target_width
254
- new_height = min(math.ceil(original_height * scale_w), target_height)
255
- else:
256
- new_height = target_height
257
- new_width = min(math.ceil(original_width * scale_h), target_width)
258
-
259
- # Resize the image
260
- resized_image = image.resize((new_width, new_height))
261
-
262
- new_image = Image.new('RGB', (target_width, target_height), (0, 0, 0))
263
- paste_x = (target_width - new_width) // 2
264
- paste_y = (target_height - new_height) // 2
265
- new_image.paste(resized_image, (paste_x, paste_y))
266
-
267
- return new_image
268
-
269
-
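A minimal sketch with assumed sizes: an 800x600 image resized toward a 768x768 target keeps its aspect ratio (768x576 content) and is centred between black bars.

    from PIL import Image

    img = Image.new("RGB", (800, 600), (255, 255, 255))
    padded = resize_and_pad_image(img, (768, 768))
    assert padded.size == (768, 768)  # content is 768x576, padded top and bottom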
270
- def divide_to_patches(image, patch_size):
271
- """
272
- Divides an image into patches of a specified size.
273
-
274
- Args:
275
- image (PIL.Image.Image): The input image.
276
- patch_size (int): The size of each patch.
277
-
278
- Returns:
279
- list: A list of PIL.Image.Image objects representing the patches.
280
- """
281
- patches = []
282
- width, height = image.size
283
- for i in range(0, height, patch_size):
284
- for j in range(0, width, patch_size):
285
- box = (j, i, j + patch_size, i + patch_size)
286
- patch = image.crop(box)
287
- patches.append(patch)
288
-
289
- return patches
290
-
291
-
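Continuing the sketch above, a 768x768 canvas split with patch_size 384 yields a 2x2 grid of crops.

    from PIL import Image

    canvas = Image.new("RGB", (768, 768))
    patches = divide_to_patches(canvas, 384)
    assert len(patches) == 4 and patches[0].size == (384, 384)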
292
- def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
293
- """
294
- Calculate the shape of the image patch grid after the preprocessing for images of any resolution.
295
-
296
- Args:
297
- image_size (tuple): The size of the input image in the format (width, height).
298
- grid_pinpoints (str): A string representation of a list of possible resolutions.
299
- patch_size (int): The size of each image patch.
300
-
301
- Returns:
302
- tuple: The shape of the image patch grid in the format (width, height).
303
- """
304
- if type(grid_pinpoints) is list:
305
- possible_resolutions = grid_pinpoints
306
- else:
307
- possible_resolutions = ast.literal_eval(grid_pinpoints)
308
- width, height = select_best_resolution(image_size, possible_resolutions)
309
- return width // patch_size, height // patch_size
310
-
311
-
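The helper accepts the pinpoints either as a list or as its string form; with patch_size 384 the 800x600 example above maps to a 2x2 patch grid (values are illustrative).

    shape = get_anyres_image_grid_shape((800, 600), "[[384, 768], [768, 384], [768, 768]]", 384)
    assert shape == (2, 2)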
312
- def process_anyres_image(image, processor, grid_pinpoints):
313
- """
314
- Process an image with variable resolutions.
315
-
316
- Args:
317
- image (PIL.Image.Image): The input image to be processed.
318
- processor: The image processor object.
319
- grid_pinpoints (str): A string representation of a list of possible resolutions.
320
-
321
- Returns:
322
- torch.Tensor: A tensor containing the processed image patches.
323
- """
324
- # FIXME: determine grid_pinpoints from image sizes.
325
- if type(grid_pinpoints) is list:
326
- possible_resolutions = grid_pinpoints
327
- else:
328
- possible_resolutions = ast.literal_eval(grid_pinpoints)
329
- best_resolution = select_best_resolution(image.size, possible_resolutions)
330
- image_padded = resize_and_pad_image(image, best_resolution)
331
-
332
- processor_size = processor.transforms[0].size
333
- patches = divide_to_patches(image_padded, processor_size[0])
334
-
335
- image_original_resize = image.resize((processor_size[0], processor_size[0]))
336
-
337
- image_patches = [image_original_resize] + patches
338
- image_patches = [processor(image_patch)
339
- for image_patch in image_patches]
340
- return torch.stack(image_patches, dim=0)
341
-
342
-
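A usage sketch under the assumption that the processor is a torchvision Compose whose first transform exposes .size, which is what the processor.transforms[0].size access above relies on; the normalization step of the real processor is omitted here.

    import torch
    from PIL import Image
    from torchvision import transforms

    processor = transforms.Compose([transforms.Resize((384, 384)), transforms.ToTensor()])
    img = Image.new("RGB", (800, 600))
    out = process_anyres_image(img, processor, [[384, 768], [768, 384], [768, 768]])
    print(out.shape)  # torch.Size([5, 3, 384, 384]): 1 resized full view + 4 grid patches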
343
- def expand2square(pil_img, background_color):
344
- width, height = pil_img.size
345
- if width == height:
346
- return pil_img
347
- elif width > height:
348
- result = Image.new(pil_img.mode, (width, width), background_color)
349
- result.paste(pil_img, (0, (width - height) // 2))
350
- return result
351
- else:
352
- result = Image.new(pil_img.mode, (height, height), background_color)
353
- result.paste(pil_img, ((height - width) // 2, 0))
354
- return result
355
-
356
-
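A one-line illustration: the shorter side is padded with the background colour so the result is square.

    from PIL import Image

    square = expand2square(Image.new("RGB", (800, 600)), (0, 0, 0))
    assert square.size == (800, 800)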
357
- def process_images(images, image_processor, model_cfg):
358
- image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
359
- new_images = []
360
- if image_aspect_ratio == 'pad':
361
- for image in images:
362
- image = expand2square(image, tuple(int(x*255) for x in image_processor.transforms[-1].mean))
363
- image = image_processor(image)
364
- new_images.append(image)
365
- elif image_aspect_ratio in ["anyres", "anyres-legacy"]:
366
- base_img_size = image_processor.transforms[0].size[0]
367
- for image in images:
368
- image = process_anyres_image(image, image_processor, [[base_img_size,base_img_size*2],
369
- [base_img_size*2,base_img_size],
370
- [base_img_size*2,base_img_size*2],
371
- [base_img_size*3,base_img_size],
372
- [base_img_size,base_img_size*3]])
373
-
374
- # Debug any res inference by only using 672x672.
375
- # image = process_anyres_image(image, image_processor, [[base_img_size*2,base_img_size*2]])
376
- new_images.append(image)
377
- else:
378
- return image_processor(images)
379
- if all(x.shape == new_images[0].shape for x in new_images):
380
- new_images = torch.stack(new_images, dim=0)
381
- return new_images
382
-
383
-
vlm.py DELETED
@@ -1,1314 +0,0 @@
1
-
2
- import torch
3
- from torch import einsum, nn
4
- from einops import rearrange, repeat
5
- from einops_exts import rearrange_many
6
- from einops import rearrange
7
- from typing import List, Optional, Tuple, Union
8
- import torch.nn.functional as F
9
- from transformers.modeling_outputs import CausalLMOutputWithPast
10
- from dataclasses import dataclass
11
- from transformers import CLIPVisionModel
12
- from transformers.models.siglip.modeling_siglip import SiglipVisionTransformer
13
-
14
- import transformers
15
- from packaging.version import Version
16
-
17
- from utils import num_params, getattr_recursive, stack_with_padding, get_anyres_image_grid_shape, unpad_image
18
-
19
-
20
- class VisionTokenizer(nn.Module):
21
- def __init__(self, dim_media, num_tokens_per_media):
22
- super().__init__()
23
- self.dim_media = dim_media
24
- self.num_tokens_per_media = num_tokens_per_media
25
-
26
- class PerceiverAttention(nn.Module):
27
- def __init__(self, *, dim, dim_head=64, heads=8):
28
- super().__init__()
29
- self.scale = dim_head**-0.5
30
- self.heads = heads
31
- inner_dim = dim_head * heads
32
-
33
- self.norm_media = nn.LayerNorm(dim)
34
- self.norm_latents = nn.LayerNorm(dim)
35
-
36
- self.to_q = nn.Linear(dim, inner_dim, bias=False)
37
- self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
38
- self.to_out = nn.Linear(inner_dim, dim, bias=False)
39
-
40
- def forward(self, x, latents, vision_attn_masks=None):
41
- """
42
- Args:
43
- x (torch.Tensor): image features
44
- shape (b, T, n1, D)
45
- latents (torch.Tensor): latent features
46
- shape (b, T, n2, D)
47
- """
48
- x = self.norm_media(x)
49
- latents = self.norm_latents(latents)
50
-
51
- h = self.heads
52
-
53
- q = self.to_q(latents)
54
- kv_input = torch.cat((x, latents), dim=-2) # TODO: Change the shape of vision attention mask according to this.
55
- if vision_attn_masks is not None:
56
- vision_attn_masks = torch.cat((vision_attn_masks,
57
- torch.ones((latents.shape[0], latents.shape[-2]), dtype=latents.dtype, device=latents.device)),
58
- dim=-1)
59
- k, v = self.to_kv(kv_input).chunk(2, dim=-1)
60
- q, k, v = rearrange_many((q, k, v), "b t n (h d) -> b h t n d", h=h)
61
- q = q * self.scale
62
-
63
- # attention
64
- sim = einsum("... i d, ... j d -> ... i j", q, k)
65
- # Apply vision attention mask here.
66
- # Reference: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html#torch.nn.functional.scaled_dot_product_attention
67
- if vision_attn_masks is not None:
68
- attn_bias = torch.zeros((q.size(0), 1, 1, q.size(-2), k.size(-2)), dtype=q.dtype, device=q.device)
69
- vision_attn_masks = repeat(vision_attn_masks, 'b n -> b 1 1 l n', l=q.size(-2))
70
- attn_bias.masked_fill_(vision_attn_masks.logical_not(), float("-inf"))
71
- sim += attn_bias
72
-
73
- sim = sim - sim.amax(dim=-1, keepdim=True).detach()
74
- attn = sim.softmax(dim=-1)
75
-
76
-
77
- out = einsum("... i j, ... j d -> ... i d", attn, v)
78
- out = rearrange(out, "b h t n d -> b t n (h d)", h=h)
79
- return self.to_out(out)
80
-
81
-
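A shape-level sketch of the attention block; the sizes below are illustrative, not the model's actual configuration. 729 vision features are pooled into 128 latent queries per media item, and the optional mask marks padded vision tokens.

    import torch

    attn = PerceiverAttention(dim=1152, dim_head=96, heads=16)
    x = torch.randn(2, 1, 729, 1152)        # (b, T, n_vis, D) image features
    latents = torch.randn(2, 1, 128, 1152)  # (b, T, n_latents, D) latent queries
    mask = torch.ones(2, 729)               # 1 = real vision token, 0 = padding
    out = attn(x, latents, vision_attn_masks=mask)
    print(out.shape)                        # torch.Size([2, 1, 128, 1152])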
82
- def FeedForward(dim, mult=4):
83
- inner_dim = int(dim * mult)
84
- return nn.Sequential(
85
- nn.LayerNorm(dim),
86
- nn.Linear(dim, inner_dim, bias=False),
87
- nn.GELU(),
88
- nn.Linear(inner_dim, dim, bias=False),
89
- )
90
-
91
-
92
- class PerceiverResampler(VisionTokenizer):
93
- def __init__(
94
- self,
95
- *,
96
- dim,
97
- dim_inner=None,
98
- depth=6,
99
- dim_head=96,
100
- heads=16,
101
- num_latents=128,
102
- max_num_media=None,
103
- max_num_frames=None,
104
- ff_mult=4,
105
- ):
106
- """
107
- Perceiver module which takes in image features and outputs image tokens.
108
- Args:
109
- dim (int): dimension of the incoming image features
110
- dim_inner (int, optional): final dimension to project the incoming image features to;
111
- also the final dimension of the outputted features. If None, no projection is used, and dim_inner = dim.
112
- depth (int, optional): number of layers. Defaults to 6.
113
- dim_head (int, optional): dimension of each head. Defaults to 96.
114
- heads (int, optional): number of heads. Defaults to 16.
115
- num_latents (int, optional): number of latent tokens to use in the Perceiver;
116
- also corresponds to number of tokens per sequence to output. Defaults to 128.
117
- max_num_media (int, optional): maximum number of media per sequence to input into the Perceiver
118
- and keep positional embeddings for. If None, no positional embeddings are used.
119
- max_num_frames (int, optional): maximum number of frames to input into the Perceiver
120
- and keep positional embeddings for. If None, no positional embeddings are used.
121
- ff_mult (int, optional): dimension multiplier for the feedforward network. Defaults to 4.
122
- """
123
- if dim_inner is not None:
124
- projection = nn.Linear(dim, dim_inner)
125
- else:
126
- projection = None
127
- dim_inner = dim
128
- super().__init__(dim_media=dim, num_tokens_per_media=num_latents)
129
- self.projection = projection
130
- self.latents = nn.Parameter(torch.randn(num_latents, dim))
131
-
132
- # positional embeddings
133
- self.frame_embs = (
134
- nn.Parameter(torch.randn(max_num_frames, dim))
135
- if exists(max_num_frames)
136
- else None
137
- )
138
- self.media_time_embs = (
139
- nn.Parameter(torch.randn(max_num_media, 1, dim))
140
- if exists(max_num_media)
141
- else None
142
- )
143
-
144
- self.layers = nn.ModuleList([])
145
- for _ in range(depth):
146
- self.layers.append(
147
- nn.ModuleList(
148
- [
149
- PerceiverAttention(
150
- dim=dim, dim_head=dim_head, heads=heads
151
- ),
152
- FeedForward(dim=dim, mult=ff_mult),
153
- ]
154
- )
155
- )
156
-
157
- self.norm = nn.LayerNorm(dim)
158
-
159
- def forward(self, x, vision_attn_masks=None):
160
- """
161
- Args:
162
- x (torch.Tensor): image features
163
- shape (b, T, F, v, D)
164
- vision_attn_masks (torch.Tensor): attention masks for padded vision tokens (i.e., x)
165
- shape (b, v)
166
- Returns:
167
- shape (b, T, n, D) where n is self.num_latents
168
- """
169
- b, T, F, v = x.shape[:4]
170
-
171
- # frame and media time embeddings
172
- if exists(self.frame_embs):
173
- frame_embs = repeat(self.frame_embs[:F], "F d -> b T F v d", b=b, T=T, v=v)
174
- x = x + frame_embs
175
- x = rearrange(
176
- x, "b T F v d -> b T (F v) d"
177
- ) # flatten the frame and spatial dimensions
178
- if exists(self.media_time_embs):
179
- x = x + self.media_time_embs[:T]
180
-
181
- # blocks
182
- latents = self.latents
183
- latents = repeat(latents, "n d -> b T n d", b=b, T=T)
184
- for attn, ff in self.layers:
185
- latents = attn(x, latents, vision_attn_masks) + latents
186
- latents = ff(latents) + latents
187
-
188
- if exists(self.projection):
189
- return self.projection(self.norm(latents))
190
- else:
191
- return self.norm(latents)
192
-
193
-
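A minimal end-to-end sketch of the resampler interface with assumed sizes: (b, T_img, F, v, D) image features in, (b, T_img, num_latents, dim_inner) vision tokens out.

    import torch

    resampler = PerceiverResampler(dim=1152, dim_inner=3072, depth=1, num_latents=128)
    feats = torch.randn(2, 1, 1, 729, 1152)  # (b, T_img, F, v, D)
    tokens = resampler(feats)
    print(tokens.shape)                      # torch.Size([2, 1, 128, 3072])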
194
- class DecoupledEmbedding(nn.Embedding):
195
- # Derived from https://pytorch.org/docs/stable/_modules/torch/nn/modules/sparse.html#Embedding
196
- """
197
- Implements a decoupling of parameters to allow freezing (or not) a subset of the embeddings. In practise, the
198
- regular `weight` can be trained or frozen (i.e. `partially_freeze=True`), and if `num_additional_embeddings` > 0,
199
- then it will create `num_additional_embeddings` additional parameters that are always trained. If
200
- `num_additional_embeddings=0`, then the module defaults back to the regular behavior of `nn.Embedding`.
201
- """
202
-
203
- def __init__(
204
- self,
205
- max_original_id: int,
206
- num_additional_embeddings: int = 0,
207
- _weight: torch.Tensor = None,
208
- num_original_embeddings: int = None,
209
- embedding_dim: int = None,
210
- partially_freeze=True,
211
- device=None,
212
- dtype=None,
213
- pad_token_id=None,
214
- ) -> None:
215
- """
216
- Args:
217
- max_original_id (`int`):
218
- The largest token id that should be embedded using the regular embedding (regular `weight`).
219
- This is usually len(tokenizer) - 1 before additional tokens are added.
220
- Note that this may not equal self.weight.shape[0]
221
- num_additional_embeddings (`int`):
222
- Number of additional tokens to initialize an Embedding matrix for (`additional_weight`).
223
- _weight (`torch.Tensor`, *optional*, defaults to `None`): The regular weight tensor.
224
- If provided, this sets the `num_original_embeddings` and `embedding_dim` parameters.
225
- num_original_embeddings (`int`):
226
- self.weight.shape[0]
227
- embedding_dim (`int`):
228
- The size of each embedding vector
229
- partially_freeze: (`bool`, *optional*, defaults to `True`):
230
- If `True`, the regular `weight` will be frozen. `additional_weight` is never frozen.
231
- pad_token_id (`int`, *optional*):
232
- The padding index (needs to be less than num_embeddings)
233
-
234
- Note: there are a lot of other parameters to initialize a standard `nn.Embedding` such as `padding_idx`,
235
- `max_norm` or `norm_type`. We are not supporting these.
236
- """
237
- # validate args
238
- if pad_token_id is not None and pad_token_id > max_original_id:
239
- raise ValueError(
240
- f"pad_token_id must be <= max_original_id. Got {pad_token_id} and {max_original_id}."
241
- + "If the original tokenizer does not have a pad_token_id, use pad_token_id=None."
242
- )
243
- if _weight is not None:
244
- assert (num_original_embeddings is None) or (
245
- _weight.shape[0] == num_original_embeddings
246
- ), f"num_original_embeddings={num_original_embeddings} but _weight.shape[0]={_weight.shape[0]}"
247
- assert (embedding_dim is None) or (
248
- _weight.shape[1] == embedding_dim
249
- ), f"embedding_dim={embedding_dim} but _weight.shape[1]={_weight.shape[1]}"
250
- num_original_embeddings = _weight.shape[0]
251
- embedding_dim = _weight.shape[1]
252
- else:
253
- assert (
254
- num_original_embeddings is not None
255
- ), "num_original_embeddings must be provided if _weight is not provided"
256
- assert (
257
- embedding_dim is not None
258
- ), "embedding_dim must be provided if _weight is not provided"
259
-
260
- super().__init__(
261
- num_embeddings=num_original_embeddings,
262
- embedding_dim=embedding_dim,
263
- device=device,
264
- dtype=dtype,
265
- padding_idx=pad_token_id,
266
- _weight=_weight,
267
- )
268
- self.max_original_id = max_original_id
269
- self.padding_idx = pad_token_id
270
- self.num_additional_embeddings = num_additional_embeddings
271
- if self.num_additional_embeddings > 0:
272
- self.additional_embedding = nn.Embedding(
273
- num_embeddings=self.num_additional_embeddings,
274
- embedding_dim=embedding_dim,
275
- device=device,
276
- dtype=dtype,
277
- )
278
- self.set_requires_grad(
279
- require_regular_grad=not partially_freeze, require_additional_grad=True
280
- )
281
-
282
- def set_requires_grad(self, require_regular_grad, require_additional_grad):
283
- """
284
- Helper function to separately set the requires_grad flag for the regular weight and the additional weight.
285
- """
286
- self.weight.requires_grad_(require_regular_grad)
287
- self.additional_embedding.requires_grad_(require_additional_grad)
288
-
289
- def forward(self, input_ids):
290
- """
291
- we have 2 embeddings, with different indices - one pretrained self.weight and another
292
- self.additional_embedding.weight that is being trained.
293
-
294
- in order to make a lookup of the input ids, we:
295
- 1. find out the indices of the entries belonging to the 2nd embedding
296
- 2. extract those values while subtracting the size of the first embedding (num_embeddings), since the 2nd
297
- embedding starts from 0 and not num_embeddings
298
- 3. perform the 2nd embedding lookup
299
- 4. now we handle the 1st embedding, we overwrite indices belonging to the 2nd embedding with a padding index
300
- 5. perform the 1st embedding lookup
301
- 6. now we overwrite the values in the 1st embedding lookup with the values of the 2nd embedding lookup
302
-
303
- note: for the 1st embedding lookup we could have looked up only the low indices and not do the padding, but
304
- then we have to create a new tensor and populate it with 2 tensors that are spread out across various indices -
305
- i.e. not a simple concat - I haven't benchmarked the complex case if it's any faster, given that seqlens are
306
- usually relatively short it's probably not faster or if faster not by much - but might be a good idea to
307
- measure.
308
-
309
- """
310
- if self.num_additional_embeddings == 0:
311
- return F.embedding(input_ids, self.weight)
312
-
313
- # Clone so that we don't modify the original input_ids later on
314
- input_ids = input_ids.clone()
315
- additional_vocab_indices = torch.where(input_ids > self.max_original_id)
316
- input_ids_additional_vocab = input_ids[additional_vocab_indices]
317
- additional_embeddings = self.additional_embedding(
318
- input_ids_additional_vocab - self.max_original_id - 1
319
- )
320
-
321
- # for successful lookup replace input_ids with 0, the results of these will be discarded anyway
322
- input_ids[additional_vocab_indices] = 0
323
- full_vector = F.embedding(input_ids, self.weight)
324
-
325
- # overwrite the records with high indices
326
- full_vector[additional_vocab_indices] = additional_embeddings
327
-
328
- return full_vector
329
-
330
- def extra_repr(self) -> str:
331
- return "num_original_embeddings={}, num_additional_embeddings={}, embedding_dim={}, partially_freeze={}".format(
332
- self.max_original_id + 1,
333
- self.num_additional_embeddings,
334
- self.embedding_dim,
335
- (not self.weight.requires_grad),
336
- )
337
-
338
-
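A toy-sized sketch of the decoupled lookup (the vocabulary sizes are assumptions): ids up to max_original_id hit the frozen base table, larger ids hit the small, always-trainable table.

    import torch

    emb = DecoupledEmbedding(
        max_original_id=9,              # base vocab covers ids 0..9
        num_additional_embeddings=2,    # new ids 10 and 11
        num_original_embeddings=10,
        embedding_dim=4,
    )
    ids = torch.tensor([[1, 5, 10, 11]])
    print(emb(ids).shape)               # torch.Size([1, 4, 4])
    print(emb.weight.requires_grad, emb.additional_embedding.weight.requires_grad)  # False True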
339
- class DecoupledLinear(nn.Linear):
340
- # Derived from https://pytorch.org/docs/stable/_modules/torch/nn/modules/linear.html#Linear
341
- """
342
- Implements a decoupling of parameters to allow freezing (or not) a subset of the parameters. In practice, the
343
- regular `weight` can be trained or frozen (i.e. `partially_freeze=True`), and if `additional_out_features` > 0,
344
- then it will create `additional_out_features * in_features` additional parameters that are always trained. If
345
- `additional_out_features=0`, then the module defaults back to the regular behavior of `nn.Linear`.
346
- """
347
-
348
- def __init__(
349
- self,
350
- max_original_id: int,
351
- additional_out_features: int = 0,
352
- _weight: torch.Tensor = None,
353
- _bias: torch.Tensor = None,
354
- in_features: int = None,
355
- original_out_features: int = None,
356
- bias: bool = True,
357
- partially_freeze: bool = True,
358
- device=None,
359
- dtype=None,
360
- ) -> None:
361
- """
362
- Args:
363
- max_original_id (`int`): The largest token id that should be extracted from the regular weight.
364
- This is usually len(tokenizer) - 1 before additional tokens are added.
365
- Note that this may not equal original_out_features - 1
366
- _weight: torch.Tensor, *optional*, defaults to `None`. The regular weight tensor.
367
- If provided, this sets the `in_features` and `original_out_features` parameters.
368
- _bias: torch.Tensor, *optional*, defaults to `None`. The regular bias tensor.
369
- in_features: int. Input hidden size.
370
- original_out_features: int. Original out_features of the language model's get_output_embeddings() function.
371
- additional_out_features: int. Number of additional trainable dimensions.
372
- bias: bool. Whether to include a bias term.
373
- partially_freeze: bool, *optional*, defaults to `True`): If `True`, the regular `weight` will be frozen.
374
- """
375
- # argument validation
376
- if _weight is not None:
377
- assert (_weight.shape[0] == original_out_features) or (
378
- original_out_features is None
379
- ), f"original_out_features={original_out_features} but _weight.shape[0]={_weight.shape[0]}"
380
- assert (_weight.shape[1] == in_features) or (
381
- in_features is None
382
- ), f"in_features={in_features} but _weight.shape[1]={_weight.shape[1]}"
383
- in_features = _weight.shape[1]
384
- original_out_features = _weight.shape[0]
385
- else:
386
- assert (
387
- in_features is not None
388
- ), "in_features must be provided if _weight is not provided"
389
- assert (
390
- original_out_features is not None
391
- ), "original_out_features must be provided if _weight is not provided"
392
-
393
- if _bias is not None:
394
- assert bias is True, "bias must be True if _bias is provided"
395
-
396
- # initialize original linear
397
- super().__init__(
398
- in_features,
399
- original_out_features,
400
- bias,
401
- device,
402
- dtype)
403
-
404
- # set weight and bias manually
405
- if _weight is not None:
406
- self.weight = nn.Parameter(_weight)
407
- if _bias is not None:
408
- self.bias = nn.Parameter(_bias)
409
-
410
- self.in_features = in_features
411
- self.original_out_features = original_out_features
412
- self.max_original_id = max_original_id
413
-
414
- # initialize additional linear
415
- self.additional_out_features = additional_out_features
416
- self.has_bias = bias
417
- if additional_out_features > 0:
418
- self.additional_fc = nn.Linear(
419
- in_features=in_features,
420
- out_features=additional_out_features,
421
- bias=self.has_bias,
422
- device=device,
423
- dtype=dtype,
424
- )
425
- self.set_requires_grad(
426
- require_regular_grad=not partially_freeze, require_additional_grad=True
427
- )
428
-
429
- def set_requires_grad(self, require_regular_grad, require_additional_grad):
430
- """
431
- Helper function to separately set the requires_grad flag for the regular weight and the additional weight.
432
- """
433
- self.weight.requires_grad_(require_regular_grad)
434
- if self.has_bias:
435
- self.bias.requires_grad_(require_regular_grad)
436
- self.additional_fc.requires_grad_(require_additional_grad)
437
-
438
- def forward(self, input: torch.Tensor) -> torch.Tensor:
439
- output = F.linear(input, self.weight, self.bias)
440
- output = output[..., : self.max_original_id + 1]
441
-
442
- if self.additional_out_features > 0:
443
- additional_features = F.linear(
444
- input, self.additional_fc.weight, self.additional_fc.bias
445
- )
446
- output = torch.cat((output, additional_features), -1)
447
- return output
448
-
449
- def extra_repr(self) -> str:
450
- """Overwriting `nn.Linear.extra_repr` to include new parameters."""
451
- return "in_features={}, out_features={}, additional_out_features={}, bias={}, partially_freeze={}".format(
452
- self.in_features,
453
- self.max_original_id + 1,
454
- self.additional_out_features,
455
- self.bias is not None,
456
- (not self.weight.requires_grad or not self.bias.requires_grad),
457
- )
458
-
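The counterpart sketch for the output head, again with toy sizes: base-vocab logits come from the frozen linear layer, logits for the new tokens from the small trainable one, concatenated along the last dimension.

    import torch

    head = DecoupledLinear(max_original_id=9, additional_out_features=2,
                           in_features=4, original_out_features=10)
    hidden = torch.randn(1, 3, 4)   # (B, T_txt, hidden)
    print(head(hidden).shape)       # torch.Size([1, 3, 12]): 10 base + 2 new logits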
459
- class VLM(nn.Module):
460
- """
461
- Generic vision-language model (VLM) class.
462
- A VLM consists of four components:
463
- 1. A vision encoder that extracts features from pixels, e.g. CLIP
464
- input: (B, T_img, F, C, H, W)
465
- output: (B, T_img, F, v, d)
466
- 2. A vision tokenizer that converts these features to visual token-like embeddings, e.g. Perceiver, or a linear projection head
467
- input: (B, T_img, F, v, d)
468
- output: (B, T_img, n, d)
469
- 3. A fusion method that allows the language model to attend to these tokens, e.g. cross-attention, or placing the tokens directly in the language model's input sequence
470
- 4. A language model
471
- """
472
-
473
- def __init__(
474
- self,
475
- vision_encoder: nn.Module,
476
- vision_tokenizer: nn.Module,
477
- lang_model: nn.Module,
478
- initial_tokenizer_len: int,
479
- pad_token_id: int,
480
- gradient_checkpointing: bool = False,
481
- ):
482
- """
483
- Args:
484
- vision_encoder (nn.Module): e.g. CLIP
485
- vision_tokenizer (nn.Module): e.g. PerceiverResampler
486
- lang_model (nn.Module): e.g. MPT
487
- initial_tokenizer_len (int): size of the original tokenizer vocab
488
- pad_token_id (int): id of the pad token
489
- gradient_checkpointing (bool, optional): Whether to use gradient checkpointing. Defaults to False.
490
- """
491
- super().__init__()
492
-
493
- # save dimension information
494
- self.lang_embedding_dim = lang_model.get_input_embeddings().weight.shape[1]
495
- if hasattr(lang_model.config, "d_model"):
496
- self.lang_hidden_dim = lang_model.config.d_model # mpt uses d_model
497
- else:
498
- self.lang_hidden_dim = lang_model.config.hidden_size
499
- self.vis_embedding_dim = vision_tokenizer.dim_media
500
- self.num_tokens_per_vis = vision_tokenizer.num_tokens_per_media
501
-
502
- # core components
503
- self.vision_encoder = vision_encoder
504
- self.vision_tokenizer = vision_tokenizer
505
- self.lang_model = lang_model
506
-
507
- # lm embeddings
508
- self.pad_token_id = pad_token_id
509
- self.initial_tokenizer_len = initial_tokenizer_len
510
- input_embeds = DecoupledEmbedding(
511
- max_original_id=initial_tokenizer_len - 1,
512
- num_additional_embeddings=len(self.special_tokens),
513
- _weight=self.lang_model.get_input_embeddings().weight,
514
- pad_token_id=self.pad_token_id,
515
- )
516
- if hasattr(input_embeds, "additional_embedding"):
517
- input_embeds.additional_embedding.weight.data.normal_(
518
- mean=0.0,
519
- std=self.lang_model.config.initializer_range
520
- if hasattr(self.lang_model.config, "initializer_range")
521
- else 0.02,
522
- )
523
- self.lang_model.set_input_embeddings(input_embeds)
524
-
525
- out_embeds = DecoupledLinear(
526
- max_original_id=initial_tokenizer_len - 1,
527
- additional_out_features=len(self.special_tokens),
528
- _weight=self.lang_model.get_output_embeddings().weight,
529
- _bias=self.lang_model.get_output_embeddings().bias if hasattr(self.lang_model.get_output_embeddings(), "bias") else None,
530
- )
531
- if hasattr(out_embeds, "additional_fc"):
532
- out_embeds.additional_fc.weight.data.normal_(
533
- mean=0.0,
534
- std=self.lang_model.config.initializer_range
535
- if hasattr(self.lang_model.config, "initializer_range")
536
- else 0.02,
537
- )
538
- self.lang_model.set_output_embeddings(out_embeds)
539
-
540
- # gradient checkpointing
541
- self.vision_tokenizer._use_gradient_checkpointing = gradient_checkpointing
542
-
543
- def forward(
544
- self,
545
- vision_x: Optional[torch.Tensor],
546
- lang_x: torch.Tensor,
547
- attention_mask: Optional[torch.Tensor] = None,
548
- labels: Optional[torch.Tensor] = None,
549
- past_key_values: Optional[
550
- List[Union[torch.Tensor, Tuple[torch.Tensor]]]
551
- ] = None,
552
- past_media_locations: Optional[torch.Tensor] = None,
553
- past_vision_tokens: Optional[torch.Tensor] = None,
554
- use_cache: Optional[bool] = False,
555
- **kwargs,
556
- ):
557
- """
558
- Args:
559
- vision_x: Vision input
560
- shape (B, T_img, F, C, H, W) with F=1
561
- only F = 1 is supported (single-frame videos)
562
- if T_img > the number of media tokens in the corresponding input_ids (lang_x),
563
- only the first number of media tokens in lang_x are used
564
- lang_x: Language input ids, with media tokens denoting where
565
- visual media should be inserted.
566
- shape (B, T_txt)
567
- attention_mask: Attention mask. Defaults to None.
568
- labels: Labels. Defaults to None.
569
- shape (B, T_txt)
570
- past_key_values (Tuple[torch.Tensor]], optional): Past key value pairs for each of the T_txt previous tokens in the language model. Defaults to None.
571
- list of length = number of decoder layers in the LM
572
- exact implementation depends on LM, see Hugging Face docs
573
- past_media_locations (torch.Tensor, optional): boolean mask denoting which of the previous T_txt tokens were media tokens. Defaults to None.
574
- shape (B, T_txt)
575
- past_vision_tokens (torch.Tensor, optional): Previous vision tokens. Defaults to None.
576
- use_cache (Optional[bool], optional): Whether to use cache. Defaults to False.
577
- If True, includes key_values, media_locations, and vision_tokens in the output.
578
- """
579
- assert not (past_vision_tokens is None) ^ (
580
- past_media_locations is None
581
- ), "past_vision_tokens and past_media_locations must both be None or both be not None"
582
-
583
- # convert pixels to vision tokens
584
- if vision_x is not None:
585
- vision_features = self._encode_vision_x(vision_x=vision_x)
586
- vision_tokens = self.vision_tokenizer(vision_features)
587
- else:
588
- vision_tokens = None
589
-
590
- # fuse the vision and language tokens
591
- new_inputs = self._prepare_inputs_for_forward(
592
- vision_tokens=vision_tokens,
593
- lang_x=lang_x,
594
- attention_mask=attention_mask,
595
- labels=labels,
596
- past_key_values=past_key_values,
597
- past_media_locations=past_media_locations,
598
- padding_side="right",
599
- past_vision_tokens=past_vision_tokens,
600
- )
601
- output = self.lang_model(
602
- **new_inputs,
603
- use_cache=use_cache,
604
- past_key_values=past_key_values,
605
- **kwargs,
606
- )
607
-
608
- # postprocessing may be needed, e.g. to remove extra tokens from logits that were inserted into the language stream
609
- # or to add the past_vision_tokens and past_media_locations to the output
610
- output = self._postprocess_outputs_from_forward(
611
- output=output,
612
- lang_x=lang_x,
613
- vision_tokens=vision_tokens,
614
- use_cache=use_cache,
615
- past_vision_tokens=past_vision_tokens,
616
- past_media_locations=past_media_locations,
617
- )
618
-
619
- # postforward hooks
620
- self._post_forward_hook()
621
- return output
622
-
623
- def _encode_vision_x_anyres(self, samples, device):
624
- assert self.anyres_grids is not None
625
- image_raw = samples["image"] # list of patch list in of shape [1, N_patch, C, H, W]
626
- image_sizes = samples["image_size"]
627
-
628
- # Image_raw can be a list of list of patches, when a `samples` has multiple images.
629
- if isinstance(image_raw[0], list):
630
- images = [x.squeeze(0) for sample_img in image_raw for x in sample_img]
631
- image_sizes = [s for sample_sizes in image_sizes for s in sample_sizes]
632
- else:
633
- # assert isinstance(image_raw[0], torch.Tensor), f"Unknown image type: {image_raw[0]}"
634
- # concatenate the list of patches into one batch for any-res encoding.
635
- images = [x.squeeze(0) for x in image_raw] # [N_patch, C, H, W]
636
- image = torch.cat(images, dim=0) # [\sum{B}{N_patch_i}, C, H, W]
637
- image = image.to(device)
638
-
639
- with torch.no_grad():
640
- if self.vision_encoder.__class__.__name__ == "TimmModel":
641
- image_embeds = self.vision_encoder.trunk.forward_features(image)
642
- elif self.vision_encoder.__class__.__name__ in ['CLIPVisionModel', 'SiglipVisionTransformer']:
643
- image_embeds = self.vision_encoder(image).last_hidden_state
644
- else:
645
- image_embeds = self.vision_encoder(image)[1] # OpenCLIP returns tuples
646
-
647
- if isinstance(self.vision_encoder, CLIPVisionModel) or isinstance(self.vision_encoder, SiglipVisionTransformer):
648
- base_img_size = self.vision_encoder.config.image_size
649
- else:
650
- base_img_size = self.vision_encoder.image_size[0]
651
-
652
- if self.vision_encoder.__class__.__name__ == "TimmModel":
653
- grid_size = self.vision_encoder.trunk.patch_embed.grid_size
654
- elif self.vision_encoder.__class__.__name__ in ['CLIPVisionModel', 'SiglipVisionTransformer']:
655
- grid_size_base = self.vision_encoder.config.image_size // self.vision_encoder.config.patch_size
656
- grid_size = (grid_size_base, grid_size_base)
657
- else:
658
- grid_size = self.vision_encoder.grid_size
659
- height, width = grid_size
660
-
661
- if not image_embeds.shape[1] == height * width:
662
- assert image_embeds.shape[1] == height * width + 1 # For vision encoders that have a [CLS] token.
663
- image_embeds = image_embeds[:, 1:, :] # Drop the cls token for each patch.
664
- n_vis_token_per_patch = image_embeds.shape[1]
665
-
666
- # Split encoded patches and merge patch features
667
- # 1. Get the raw sizes from samples, and split the image embeds [\sum_{B}(N_patch_i), N_tok(16*16), C]
668
- split_sizes = [image.shape[0] for image in images]
669
- image_embeds = torch.split(image_embeds, split_sizes, dim=0)
670
- # 2. For each image (consisting of a list of patches), merge the patches spatially (of shape [C, n_patch_height, n_patch_width])
671
- new_image_embeds = []
672
- patch_attn_masks = []
673
- max_n_img_token = -1
674
- for idx, patch_embeds in enumerate(image_embeds):
675
- if patch_embeds.shape[0] > 1:
676
- # 3. Flatten the patch features and get [C, n_patch_height * (n_patch_width+1)]
677
- base_patch_embeds = patch_embeds[0] # TODO: prepend the CLS token for the base patch embeds (of the resized entire image).
678
- patch_embeds = patch_embeds[1:]
679
-
680
- assert height * width == base_patch_embeds.shape[0]
681
-
682
- num_patch_width, num_patch_height = get_anyres_image_grid_shape(image_sizes[idx],
683
- self.anyres_grids,
684
- base_img_size) # Hardcoded grid_pinpoints.
685
- patch_embeds = patch_embeds.view(num_patch_height, num_patch_width, height, width, -1)
686
-
687
- patch_embeds = patch_embeds.permute(4, 0, 2, 1, 3).contiguous()
688
- patch_embeds = patch_embeds.flatten(1, 2).flatten(2, 3)
689
- patch_embeds, patch_attn_mask = unpad_image(patch_embeds, image_sizes[idx], self.anyres_patch_sampling)
690
- if hasattr(self, 'image_newline'):
691
- patch_embeds = torch.cat((
692
- patch_embeds,
693
- self.image_newline[:, None, None].expand(*patch_embeds.shape[:-1], 1)
694
- ), dim=-1)
695
- if self.anyres_patch_sampling:
696
- patch_embeds = patch_embeds.view(-1, num_patch_height, num_patch_width, height*width)
697
- patch_embeds = patch_embeds.flatten(1, 2).permute(1, 2, 0)
698
- assert patch_attn_mask is not None
699
- patch_attn_mask = patch_attn_mask.view(num_patch_height, num_patch_width, height*width)
700
- patch_attn_mask = patch_attn_mask.flatten(0, 1)
701
- patch_embeds = torch.cat((base_patch_embeds.unsqueeze(0), patch_embeds), dim=0)
702
- patch_attn_mask = torch.cat((torch.ones(n_vis_token_per_patch, device=patch_embeds.device).unsqueeze(0), patch_attn_mask), dim=0)
703
- else:
704
- patch_embeds = patch_embeds.flatten(1, 2).transpose(0, 1)
705
- patch_embeds = torch.cat((base_patch_embeds, patch_embeds), dim=0)
706
- else:
707
- patch_embeds = patch_embeds[0].unsqueeze(0) if self.anyres_patch_sampling else patch_embeds[0]
708
- patch_attn_mask = torch.ones(n_vis_token_per_patch, device=patch_embeds.device).unsqueeze(0) if self.anyres_patch_sampling else None
709
- if hasattr(self, 'image_newline'):
710
- patch_embeds = torch.cat((
711
- patch_embeds,
712
- self.image_newline[None]
713
- ), dim=0)
714
- if not self.anyres_patch_sampling:
715
- max_n_img_token = max(patch_embeds.shape[0], max_n_img_token)
716
-
717
- new_image_embeds.append(patch_embeds)
718
- patch_attn_masks.append(patch_attn_mask)
719
-
720
- if self.anyres_patch_sampling:
721
- # Return individual patches for independent token downsampling.
722
- return new_image_embeds, patch_attn_masks
723
-
724
- # 4. Pad and concat the list of image_embeds [N_tok_i, C] together into a batch. Also modify the query attention mask.
725
- image_embeds = []
726
- image_atts = []
727
- for image_embed in new_image_embeds:
728
- n_img_token = image_embed.shape[0]
729
- img_attn = torch.ones((max_n_img_token), dtype=torch.long, device=image_embed.device)
730
- if n_img_token < max_n_img_token:
731
- padded_embed = torch.zeros((max_n_img_token, image_embed.shape[-1]), dtype=image_embed.dtype, device=image_embed.device)
732
- padded_embed[:n_img_token, :] = image_embed
733
- img_attn[n_img_token:] = 0 # Mask out the padded entries.
734
- else:
735
- padded_embed = image_embed
736
- image_embeds.append(padded_embed)
737
- image_atts.append(img_attn)
738
- image_embeds = torch.stack(image_embeds, dim=0) # Shape [B, N_tok_longest, C_dim]
739
- image_atts = torch.stack(image_atts, dim=0) # Shape [B, N_tok_longest]
740
- # TODO: reshape image_embeds and image_atts to "b T F v d"
741
- image_embeds = image_embeds[:, None, None, :, :]
742
- # image_atts = image_atts[:, None, None, :, :]
743
-
744
- return image_embeds, image_atts
745
-
746
- def _encode_vision_x(self, vision_x: torch.Tensor):
747
- """
748
- Compute media tokens from vision input by passing it through vision encoder and conditioning language model.
749
- Args:
750
- vision_x: Vision input
751
- shape (B, T_img, F, C, H, W)
752
- Images in the same chunk are collated along T_img, and frames are collated along F
753
- Currently only F=1 is supported (single-frame videos)
754
-
755
- rearrange code based on https://github.com/dhansmair/flamingo-mini
756
- """
757
- assert vision_x.ndim == 6, "vision_x should be of shape (b, T_img, F, C, H, W)"
758
- b, T, F = vision_x.shape[:3]
759
-
760
- vision_x = rearrange(vision_x, "b T F c h w -> (b T F) c h w")
761
- with torch.no_grad():
762
- if self.vision_encoder.__class__.__name__ == "TimmModel":
763
- vision_x = self.vision_encoder.trunk.forward_features(vision_x)
764
- elif self.vision_encoder.__class__.__name__ in ['CLIPVisionModel', 'SiglipVisionTransformer']:
765
- vision_x = self.vision_encoder(vision_x).last_hidden_state
766
- else:
767
- vision_x = self.vision_encoder(vision_x)[1] # OpenCLIP returns tuples
768
- vision_x = rearrange(vision_x, "(b T F) v d -> b T F v d", b=b, T=T, F=F)
769
- return vision_x
770
-
771
- def _concat_vision_cache(
772
- self, lang_x, vision_tokens, past_vision_tokens, past_media_locations, use_cache
773
- ):
774
- """
775
- Helper function to include the past vision tokens and past media locations in the output.
776
- """
777
- if use_cache:
778
- if past_media_locations is not None and past_vision_tokens is not None:
779
- if vision_tokens is not None:
780
- updated_vision_tokens = torch.cat(
781
- [
782
- past_vision_tokens,
783
- vision_tokens,
784
- ],
785
- dim=1,
786
- )
787
- else:
788
- updated_vision_tokens = past_vision_tokens
789
- updated_media_locations = torch.cat(
790
- [
791
- past_media_locations,
792
- lang_x == self.media_token_id,
793
- ],
794
- dim=1,
795
- )
796
- else:
797
- updated_vision_tokens = vision_tokens
798
- updated_media_locations = lang_x == self.media_token_id
799
-
800
- else:
801
- updated_vision_tokens = None
802
- updated_media_locations = None
803
-
804
- return updated_vision_tokens, updated_media_locations
805
-
806
- def generate(
807
- self,
808
- vision_x: torch.Tensor,
809
- lang_x: torch.Tensor,
810
- attention_mask: torch.Tensor = None,
811
- past_key_values: Optional[
812
- List[Union[torch.Tensor, Tuple[torch.Tensor]]]
813
- ] = None,
814
- past_media_locations: Optional[torch.Tensor] = None,
815
- past_vision_tokens: Optional[torch.Tensor] = None,
816
- **kwargs,
817
- ):
818
- """
819
- Generate text conditioned on vision and language inputs.
820
- Args:
821
- vision_x (torch.Tensor): Vision input
822
- shape (B, T_img, F, C, H, W)
823
- see documentation for forward
824
- lang_x (torch.Tensor): Language input
825
- shape (B, T_txt)
826
- attention_mask (torch.Tensor, optional): Attention mask. Defaults to None.
827
- **kwargs: see generate documentation in Hugging Face CausalLM models.
828
- Returns:
829
- torch.Tensor: lang_x with generated tokens appended to it
830
- """
831
- num_beams = kwargs.pop("num_beams", 1)
832
-
833
- # convert pixels to vision tokens
834
- if vision_x is not None:
835
- vision_features = self._encode_vision_x(vision_x=vision_x)
836
- vision_tokens = self.vision_tokenizer(vision_features)
837
- else:
838
- vision_tokens = None
839
-
840
- # fuse the vision and language tokens
841
- # for xattn, vision_x and media_location are repeat_interleaved s.t.
842
- # the total batch size is B * num_beams
843
- new_inputs = self._prepare_inputs_for_forward(
844
- vision_tokens=vision_tokens,
845
- lang_x=lang_x,
846
- attention_mask=attention_mask,
847
- past_key_values=past_key_values,
848
- past_media_locations=past_media_locations,
849
- past_vision_tokens=past_vision_tokens,
850
- padding_side="left",
851
- num_beams=num_beams,
852
- )
853
- output = self.lang_model.generate(
854
- **new_inputs,
855
- past_key_values=past_key_values,
856
- num_beams=num_beams,
857
- use_cache=True,
858
- **kwargs,
859
- )
860
- self._post_forward_hook()
861
- return output
862
-
863
- @property
864
- def num_trainable_params(self):
865
- """Print the number of trainable parameters"""
866
- return num_params(self, filter_to_trainable=True)
867
-
868
- def set_trainable(self):
869
- """
870
- Freeze appropriate parameters in the model.
871
- """
872
- raise NotImplementedError
873
-
874
- def group_params_by_weight_decay(self):
875
- """
876
- Return a tuple of (params to optimize w/ weight decay, params to optimize w/o weight decay)
877
- """
878
- params_with_wd, params_without_wd = [], []
879
- for n, p in self.named_parameters():
880
- if p.requires_grad:
881
- if self._should_apply_weight_decay(n):
882
- params_with_wd.append(p)
883
- else:
884
- params_without_wd.append(p)
885
- return params_with_wd, params_without_wd
886
-
887
- def _should_apply_weight_decay(self, parameter_name):
888
- """
889
- Return whether weight decay should be applied to a parameter.
890
- """
891
- raise NotImplementedError
892
-
893
- @property
894
- def special_tokens(self):
895
- """
896
- Returns a dict mapping from the attribute name of a special token to its string format,
897
- e.g. "media_token": "<image>"
898
- """
899
- assert (
900
- "media_token" in self._special_tokens
901
- ), "VLMs need to request that the tokenizer add a media_token and call set_special_token_ids to set self.media_token_id"
902
- return self._special_tokens
903
-
904
- @property
905
- def special_token_ids(self):
906
- """
907
- Returns a list of the special token ids
908
- """
909
- return [getattr(self, f"{att_name}_id") for att_name in self.special_tokens]
910
-
911
- def set_special_token_ids(self, string_to_ids):
912
- """
913
- Args:
914
- string_to_ids (dict): mapping from token string to id
915
- """
916
- assert set(self.special_tokens.values()).issubset(set(string_to_ids.keys()))
917
- for att_name, token_str in self.special_tokens.items():
918
- token_id = string_to_ids[token_str]
919
- setattr(self, f"{att_name}_id", token_id)
920
- setattr(self.lang_model, f"{att_name}_id", token_id)
921
-
922
- def init_gradient_checkpointing(self):
923
- from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
924
- checkpoint_wrapper,
925
- CheckpointWrapper,
926
- CheckpointImpl,
927
- apply_activation_checkpointing,
928
- )
929
- from functools import partial
930
-
931
- non_reentrant_wrapper = partial(
932
- checkpoint_wrapper,
933
- checkpoint_impl=CheckpointImpl.NO_REENTRANT,
934
- )
935
- apply_activation_checkpointing(
936
- self,
937
- checkpoint_wrapper_fn=non_reentrant_wrapper,
938
- check_fn=lambda m: getattr(m, "_use_gradient_checkpointing", False)
939
- and not isinstance(m, CheckpointWrapper),
940
- )
941
-
942
- @dataclass
943
- class VLMOutputWithPast(CausalLMOutputWithPast):
944
- """
945
- VLMOutputWithPast is a wrapper around CausalLMOutputWithPast that adds the following attributes:
946
- past_media_locations: Optional[torch.Tensor] = None,
947
- past_vision_tokens: Optional[torch.Tensor] = None,
948
- """
949
-
950
- past_media_locations: Optional[torch.Tensor] = None
951
- past_vision_tokens: Optional[torch.Tensor] = None
952
-
953
-
954
- def exists(val):
955
- return val is not None
956
-
957
-
958
- def FeedForward(dim, mult=4):
959
- inner_dim = int(dim * mult)
960
- return nn.Sequential(
961
- nn.LayerNorm(dim),
962
- nn.Linear(dim, inner_dim, bias=False),
963
- nn.GELU(),
964
- nn.Linear(inner_dim, dim, bias=False),
965
- )
966
-
967
- class VLMWithLanguageStream(VLM):
968
- """
969
- VLM that fuses modalities by inserting vision tokens directly into the language stream.
970
- """
971
-
972
- def __init__(
973
- self,
974
- vision_encoder: nn.Module,
975
- vision_tokenizer: nn.Module,
976
- lang_model: nn.Module,
977
- initial_tokenizer_len: int,
978
- pad_token_id: int,
979
- decoder_layers_attr_name: str = None,
980
- gradient_checkpointing: bool = False,
981
- ):
982
- super().__init__(
983
- vision_encoder=vision_encoder,
984
- vision_tokenizer=vision_tokenizer,
985
- lang_model=lang_model,
986
- initial_tokenizer_len=initial_tokenizer_len,
987
- pad_token_id=pad_token_id,
988
- gradient_checkpointing=gradient_checkpointing,
989
- )
990
- self.decoder_layers_attr_name = decoder_layers_attr_name
991
- if decoder_layers_attr_name is not None:
992
- for block in getattr_recursive(self.lang_model, self.decoder_layers_attr_name):
993
- block._use_gradient_checkpointing = gradient_checkpointing
994
-
995
- def _prepare_inputs_for_forward(
996
- self,
997
- vision_tokens: torch.Tensor,
998
- lang_x: torch.Tensor,
999
- attention_mask: torch.Tensor,
1000
- labels: torch.Tensor = None,
1001
- past_key_values=None,
1002
- vision_attention_mask: Optional[torch.Tensor] = None,
1003
- past_media_locations: torch.Tensor = None,
1004
- past_vision_tokens: torch.Tensor = None,
1005
- padding_side: str = "left",
1006
- num_beams: int = 1,
1007
- ):
1008
- """
1009
- Insert the vision tokens directly into the language stream.
1010
- This requires us to modify the input_ids, attention_mask, and labels.
1011
- """
1012
- if past_key_values is not None:
1013
- past_len = past_key_values[0][0].shape[2]
1014
- assert attention_mask.shape[1] == past_len + lang_x.shape[1], (
1015
- "Attention_mask must be as long as the entire past len (including image tokens) and current input IDs. "
1016
- + "Check that you've expanded the attention mask to account for past image tokens."
1017
- )
1018
-
1019
- if vision_tokens is None:
1020
- return {
1021
- "input_ids": lang_x,
1022
- "attention_mask": attention_mask,
1023
- "labels": labels,
1024
- }
1025
-
1026
- # get the language embeddings
1027
- lang_embeds = self.lang_model.get_input_embeddings()(lang_x)
1028
-
1029
- # build up the multimodal embeddings
1030
- B = lang_x.shape[0]
1031
- has_labels = labels is not None
1032
- multimodal_embeds = []
1033
- multimodal_attention_mask = []
1034
- multimodal_labels = [] if has_labels else None
1035
- for i in range(B):
1036
- # get index of <image> tokens in lang_x[i]
1037
- image_token_idxs = torch.where(lang_x[i] == self.media_token_id)[0]
1038
-
1039
- if len(image_token_idxs) == 0:
1040
- multimodal_embeds.append(lang_embeds[i].clone())
1041
- multimodal_attention_mask.append(attention_mask[i].clone())
1042
- if has_labels:
1043
- multimodal_labels.append(labels[i].clone())
1044
- continue
1045
-
1046
- # loop through the image_token_idxs and insert the vision tokens
1047
- new_embed = lang_embeds[i].clone()
1048
- new_attention_mask = (
1049
- attention_mask[i].clone() if attention_mask is not None else None
1050
- )
1051
- if has_labels:
1052
- new_label = labels[i].clone()
1053
- # print(vision_tokens.shape)  # debug only
1054
- for img_num, img_idx in enumerate(image_token_idxs):
1055
- new_embed = torch.cat(
1056
- (
1057
- new_embed[:img_idx],
1058
- vision_tokens[i][img_num],
1059
- new_embed[img_idx + self.num_tokens_per_vis :],
1060
- ),
1061
- dim=0,
1062
- )
1063
- new_attention_mask = torch.cat(
1064
- (
1065
- new_attention_mask[:img_idx],
1066
- torch.ones(self.num_tokens_per_vis, dtype=torch.long).to(
1067
- attention_mask.device
1068
- ),
1069
- new_attention_mask[img_idx + self.num_tokens_per_vis :],
1070
- ),
1071
- dim=0,
1072
- )
1073
- if has_labels:
1074
- new_label = torch.cat(
1075
- (
1076
- new_label[:img_idx],
1077
- torch.ones(self.num_tokens_per_vis, dtype=torch.long).to(
1078
- labels.device
1079
- )
1080
- * -100,
1081
- new_label[img_idx + self.num_tokens_per_vis :],
1082
- ),
1083
- dim=0,
1084
- )
1085
- multimodal_embeds.append(new_embed)
1086
- multimodal_attention_mask.append(new_attention_mask)
1087
- if has_labels:
1088
- multimodal_labels.append(new_label)
1089
-
1090
- # stack
1091
- multimodal_embeds = stack_with_padding(
1092
- multimodal_embeds,
1093
- padding_value=self.pad_token_id,
1094
- padding_side=padding_side,
1095
- )
1096
- multimodal_attention_mask = stack_with_padding(
1097
- multimodal_attention_mask,
1098
- padding_value=0,
1099
- padding_side=padding_side,
1100
- )
1101
- if has_labels:
1102
- multimodal_labels = stack_with_padding(
1103
- multimodal_labels,
1104
- padding_value=-100,
1105
- padding_side=padding_side,
1106
- )
1107
-
1108
- return {
1109
- "inputs_embeds": multimodal_embeds,
1110
- "attention_mask": multimodal_attention_mask,
1111
- "labels": multimodal_labels,
1112
- }
1113
-
1114
- def _postprocess_outputs_from_forward(
1115
- self,
1116
- output: CausalLMOutputWithPast,
1117
- lang_x: torch.Tensor,
1118
- vision_tokens: torch.Tensor,
1119
- past_vision_tokens: torch.Tensor,
1120
- past_media_locations: torch.Tensor,
1121
- use_cache: bool = False,
1122
- ):
1123
- # Include the past vision tokens and past media locations in the output
1124
- updated_vision_tokens, updated_media_locations = self._concat_vision_cache(
1125
- lang_x=lang_x,
1126
- vision_tokens=vision_tokens,
1127
- past_vision_tokens=past_vision_tokens,
1128
- past_media_locations=past_media_locations,
1129
- use_cache=use_cache,
1130
- )
1131
-
1132
- # return logits that are the same shape as the original input_ids
1133
- logits = output.logits
1134
- batch_logits = []
1135
- B, T_txt = lang_x.shape
1136
- for i in range(B):
1137
- sequence_logits = []
1138
- logits_j = 0
1139
- for j in range(T_txt):
1140
- if lang_x[i, j] != self.media_token_id:
1141
- sequence_logits.append(logits[i, logits_j])
1142
- logits_j += 1
1143
- else:
1144
- # append the logit for the first image token, then skip over the rest
1145
- # note: the model actually learns to predict <im_patch>, not <image>
1146
- sequence_logits.append(logits[i, logits_j])
1147
- logits_j += self.num_tokens_per_vis
1148
- sequence_logits = torch.stack(sequence_logits, dim=0) # (T_txt, vocab_size)
1149
- batch_logits.append(sequence_logits)
1150
-
1151
- batch_logits = torch.stack(batch_logits, dim=0) # (B, T_txt, vocab_size)
1152
- # The final logits shape should be the same as the original input_ids shape
1153
- assert batch_logits.shape[:2] == (B, T_txt)
1154
-
1155
- # assemble the output
1156
- output = VLMOutputWithPast(
1157
- loss=output.loss,
1158
- logits=batch_logits,
1159
- past_key_values=output.past_key_values,
1160
- hidden_states=output.hidden_states,
1161
- attentions=output.attentions,
1162
- past_media_locations=updated_media_locations,
1163
- past_vision_tokens=updated_vision_tokens,
1164
- )
1165
-
1166
- return output
1167
-
1168
- def _post_forward_hook(self):
1169
- pass
1170
-
1171
-
1172
- @property
1173
- def num_params_per_module(self):
1174
- """Print the number of parameters per module in the model"""
1175
- return "\n".join(
1176
- [
1177
- f"Vision encoder: {num_params(self.vision_encoder):,} parameters",
1178
- f"Vision tokenizer: {num_params(self.vision_tokenizer):,} parameters",
1179
- f"Language model: {num_params(self.lang_model):,} parameters",
1180
- ]
1181
- )
1182
-
1183
- @property
1184
- def num_trainable_params_per_module(self):
1185
- """Print the number of trainable parameters per module in the model"""
1186
- return "\n".join(
1187
- [
1188
- f"Vision encoder: {num_params(self.vision_encoder, filter_to_trainable=True):,} trainable parameters",
1189
- f"Vision tokenizer: {num_params(self.vision_tokenizer, filter_to_trainable=True):,} trainable parameters",
1190
- f"Language model: {num_params(self.lang_model, filter_to_trainable=True):,} trainable parameters",
1191
- ]
1192
- )
1193
-
1194
-
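A standalone sketch of the fusion step that _prepare_inputs_for_forward above performs per example: the embeddings at the reserved media positions are replaced by the vision tokens, so the sequence length is unchanged. The toy vocabulary, MEDIA_ID and sizes are assumptions for illustration only.

    import torch

    MEDIA_ID, N_VIS, DIM = 9, 3, 8
    lang_x = torch.tensor([5, MEDIA_ID, MEDIA_ID, MEDIA_ID, 7])  # 3 reserved media positions
    lang_embeds = torch.nn.Embedding(10, DIM)(lang_x)            # (T_txt, DIM)
    vision_tokens = torch.randn(N_VIS, DIM)                      # output of the vision tokenizer

    img_idx = int(torch.where(lang_x == MEDIA_ID)[0][0])
    fused = torch.cat((lang_embeds[:img_idx],
                       vision_tokens,
                       lang_embeds[img_idx + N_VIS:]), dim=0)
    print(fused.shape)                                           # torch.Size([5, 8])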
1195
- class XGenMMPerceiver(VLMWithLanguageStream):
1196
- def __init__(
1197
- self,
1198
- vision_encoder: nn.Module,
1199
- vision_tokenizer: nn.Module,
1200
- lang_model: nn.Module,
1201
- initial_tokenizer_len: int,
1202
- pad_token_id: int,
1203
- decoder_layers_attr_name: str = None,
1204
- gradient_checkpointing: bool = False,
1205
- image_aspect_ratio: str = 'none',
1206
- ):
1207
- """
1208
- Args:
1209
- vision_encoder (nn.Module): HF CLIPModel
1210
- lang_model (nn.Module): HF causal language model
1211
- vis_feature_dim (int): final dimension of the visual features outputted by the vision_encoder
1212
- initial_tokenizer_len (int): size of the tokenizer vocab
1213
- pad_token_id (int): id of the padding token. None if no padding token; then a padding token
1214
- will be inserted into self.special_tokens, which factory.py fills after creating new tokens
1215
- decoder_layers_attr_name (str, optional): name of the decoder layers attribute. Defaults to None.
1216
- gradient_checkpointing (bool, optional): whether to use gradient checkpointing. Defaults to False.
1217
- """
1218
- self._special_tokens = {
1219
- "media_token": "<image>",
1220
- "image_placeholder_token": "<image placeholder>",
1221
- "end_of_trunk_token": "<|endofchunk|>",
1222
- }
1223
- lang_embedding_dim = lang_model.get_input_embeddings().weight.shape[1]
1224
- super().__init__(
1225
- vision_encoder=vision_encoder,
1226
- vision_tokenizer=vision_tokenizer,
1227
- lang_model=lang_model,
1228
- initial_tokenizer_len=initial_tokenizer_len,
1229
- gradient_checkpointing=gradient_checkpointing,
1230
- decoder_layers_attr_name=decoder_layers_attr_name,
1231
- pad_token_id=pad_token_id,
1232
- )
1233
- self.image_aspect_ratio = image_aspect_ratio
1234
-
1235
- def set_trainable(self):
1236
- """
1237
- Unfreeze everything except the vision_encoder
1238
- """
1239
- self.requires_grad_(True)
1240
- self.vision_encoder.requires_grad_(False)
1241
-
1242
- def _should_apply_weight_decay(self, parameter_name):
1243
- """
1244
- Kosmos applies 0.01 weight decay to everything
1245
- """
1246
- return True
1247
-
1248
- def generate(
1249
- self,
1250
- vision_x: torch.Tensor,
1251
- lang_x: torch.Tensor,
1252
- image_size: Optional[Tuple] = None,
1253
- attention_mask: torch.Tensor = None,
1254
- past_key_values: Optional[
1255
- List[Union[torch.Tensor, Tuple[torch.Tensor]]]
1256
- ] = None,
1257
- past_media_locations: Optional[torch.Tensor] = None,
1258
- past_vision_tokens: Optional[torch.Tensor] = None,
1259
- **kwargs,
1260
- ):
1261
- """
1262
- Generate text conditioned on vision and language inputs.
1263
- Args:
1264
- vision_x (torch.Tensor): Vision input
1265
- shape (B, T_img, F, C, H, W)
1266
- see documentation for forward
1267
- lang_x (torch.Tensor): Language input
1268
- shape (B, T_txt)
1269
- attention_mask (torch.Tensor, optional): Attention mask. Defaults to None.
1270
- **kwargs: see generate documentation in Hugging Face CausalLM models.
1271
- Returns:
1272
- torch.Tensor: lang_x with generated tokens appended to it
1273
- """
1274
- num_beams = kwargs.pop("num_beams", 1)
1275
-
1276
- # convert pixels to vision tokens
1277
- vision_attention_mask = None
1278
- if vision_x is not None:
1279
- vision_features = self._encode_vision_x(vision_x=vision_x)
1280
- vision_tokens = self.vision_tokenizer(vision_features)
1281
- else:
1282
- vision_tokens = None
1283
-
1284
- # fuse the vision and language tokens
1285
- # for xattn, vision_x and media_location are repeat_interleaved s.t.
1286
- # the total batch size is B * num_beams
1287
- new_inputs = self._prepare_inputs_for_forward(
1288
- vision_tokens=vision_tokens,
1289
- lang_x=lang_x,
1290
- attention_mask=attention_mask,
1291
- vision_attention_mask=vision_attention_mask,
1292
- past_key_values=past_key_values,
1293
- past_media_locations=past_media_locations,
1294
- past_vision_tokens=past_vision_tokens,
1295
- padding_side="left",
1296
- num_beams=num_beams,
1297
- )
1298
- if past_key_values is not None:
1299
- output = self.lang_model.generate(
1300
- **new_inputs,
1301
- past_key_values=past_key_values,
1302
- num_beams=num_beams,
1303
- use_cache=True,
1304
- **kwargs,
1305
- )
1306
- else:
1307
- output = self.lang_model.generate(
1308
- **new_inputs,
1309
- num_beams=num_beams,
1310
- use_cache=True,
1311
- **kwargs,
1312
- )
1313
- self._post_forward_hook()
1314
- return output