ucaslcl commited on
Commit
a31b327
1 Parent(s): 74d5f2d

Upload 3 files

Browse files
Files changed (3) hide show
  1. config.json +38 -0
  2. got_vision_b.py +468 -0
  3. modeling_GOT.py +659 -0
config.json ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "ucaslcl/GOT-OCR2_0",
3
+ "architectures": [
4
+ "GOTQwenForCausalLM"
5
+ ],
6
+ "auto_map": {
7
+ "AutoConfig": "modeling_GOT.GOTConfig",
8
+ "AutoModel": "modeling_GOT.GOTQwenForCausalLM"
9
+ },
10
+ "attention_dropout": 0.0,
11
+ "bos_token_id": 151643,
12
+ "eos_token_id": 151643,
13
+ "freeze_vision_tower": false,
14
+ "hidden_act": "silu",
15
+ "hidden_size": 1024,
16
+ "im_end_token": 151858,
17
+ "im_patch_token": 151859,
18
+ "im_start_token": 151857,
19
+ "image_token_len": 256,
20
+ "initializer_range": 0.02,
21
+ "intermediate_size": 2816,
22
+ "max_position_embeddings": 32768,
23
+ "max_window_layers": 21,
24
+ "model_type": "GOT",
25
+ "num_attention_heads": 16,
26
+ "num_hidden_layers": 24,
27
+ "num_key_value_heads": 16,
28
+ "rms_norm_eps": 1e-06,
29
+ "rope_theta": 1000000.0,
30
+ "sliding_window": 32768,
31
+ "tie_word_embeddings": true,
32
+ "torch_dtype": "bfloat16",
33
+ "transformers_version": "4.37.2",
34
+ "use_cache": true,
35
+ "use_im_start_end": true,
36
+ "use_sliding_window": false,
37
+ "vocab_size": 151860
38
+ }
got_vision_b.py ADDED
@@ -0,0 +1,468 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from typing import Optional, Tuple, Type
4
+ from functools import partial
5
+ import torch.nn as nn
6
+ from typing import Type
7
+
8
+
9
+
10
+ class MLPBlock(nn.Module):
11
+ def __init__(
12
+ self,
13
+ embedding_dim: int,
14
+ mlp_dim: int,
15
+ act: Type[nn.Module] = nn.GELU,
16
+ ) -> None:
17
+ super().__init__()
18
+ self.lin1 = nn.Linear(embedding_dim, mlp_dim)
19
+ self.lin2 = nn.Linear(mlp_dim, embedding_dim)
20
+ self.act = act()
21
+
22
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
23
+ return self.lin2(self.act(self.lin1(x)))
24
+
25
+
26
+
27
+ class LayerNorm2d(nn.Module):
28
+ def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
29
+ super().__init__()
30
+ self.weight = nn.Parameter(torch.ones(num_channels))
31
+ self.bias = nn.Parameter(torch.zeros(num_channels))
32
+ self.eps = eps
33
+
34
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
35
+ u = x.mean(1, keepdim=True)
36
+ s = (x - u).pow(2).mean(1, keepdim=True)
37
+ x = (x - u) / torch.sqrt(s + self.eps)
38
+ x = self.weight[:, None, None] * x + self.bias[:, None, None]
39
+ return x
40
+
41
+
42
+
43
+ class ImageEncoderViT(nn.Module):
44
+ def __init__(
45
+ self,
46
+ img_size: int = 1024,
47
+ patch_size: int = 16,
48
+ in_chans: int = 3,
49
+ embed_dim: int = 768,
50
+ depth: int = 12,
51
+ num_heads: int = 12,
52
+ mlp_ratio: float = 4.0,
53
+ out_chans: int = 256,
54
+ qkv_bias: bool = True,
55
+ norm_layer: Type[nn.Module] = nn.LayerNorm,
56
+ act_layer: Type[nn.Module] = nn.GELU,
57
+ use_abs_pos: bool = True,
58
+ use_rel_pos: bool = False,
59
+ rel_pos_zero_init: bool = True,
60
+ window_size: int = 0,
61
+ global_attn_indexes: Tuple[int, ...] = (),
62
+ ) -> None:
63
+ """
64
+ Args:
65
+ img_size (int): Input image size.
66
+ patch_size (int): Patch size.
67
+ in_chans (int): Number of input image channels.
68
+ embed_dim (int): Patch embedding dimension.
69
+ depth (int): Depth of ViT.
70
+ num_heads (int): Number of attention heads in each ViT block.
71
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
72
+ qkv_bias (bool): If True, add a learnable bias to query, key, value.
73
+ norm_layer (nn.Module): Normalization layer.
74
+ act_layer (nn.Module): Activation layer.
75
+ use_abs_pos (bool): If True, use absolute positional embeddings.
76
+ use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
77
+ rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
78
+ window_size (int): Window size for window attention blocks.
79
+ global_attn_indexes (list): Indexes for blocks using global attention.
80
+ """
81
+ super().__init__()
82
+ self.img_size = img_size
83
+
84
+ self.patch_embed = PatchEmbed(
85
+ kernel_size=(patch_size, patch_size),
86
+ stride=(patch_size, patch_size),
87
+ in_chans=in_chans,
88
+ embed_dim=embed_dim,
89
+ )
90
+
91
+ self.pos_embed: Optional[nn.Parameter] = None
92
+ if use_abs_pos:
93
+ # Initialize absolute positional embedding with pretrain image size.
94
+ self.pos_embed = nn.Parameter(
95
+ torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim)
96
+ )
97
+
98
+ self.blocks = nn.ModuleList()
99
+ for i in range(depth):
100
+ block = Block(
101
+ dim=embed_dim,
102
+ num_heads=num_heads,
103
+ mlp_ratio=mlp_ratio,
104
+ qkv_bias=qkv_bias,
105
+ norm_layer=norm_layer,
106
+ act_layer=act_layer,
107
+ use_rel_pos=use_rel_pos,
108
+ rel_pos_zero_init=rel_pos_zero_init,
109
+ window_size=window_size if i not in global_attn_indexes else 0,
110
+ input_size=(img_size // patch_size, img_size // patch_size),
111
+ )
112
+ self.blocks.append(block)
113
+
114
+ self.neck = nn.Sequential(
115
+ nn.Conv2d(
116
+ embed_dim,
117
+ out_chans,
118
+ kernel_size=1,
119
+ bias=False,
120
+ ),
121
+ LayerNorm2d(out_chans),
122
+ nn.Conv2d(
123
+ out_chans,
124
+ out_chans,
125
+ kernel_size=3,
126
+ padding=1,
127
+ bias=False,
128
+ ),
129
+ LayerNorm2d(out_chans),
130
+ )
131
+
132
+
133
+ self.net_2 = nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1, bias=False)
134
+ self.net_3 = nn.Conv2d(512, 1024, kernel_size=3, stride=2, padding=1, bias=False)
135
+
136
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
137
+ x = self.patch_embed(x)
138
+ if self.pos_embed is not None:
139
+ x = x + self.pos_embed
140
+
141
+ for blk in self.blocks:
142
+ x = blk(x)
143
+
144
+ x = self.neck(x.permute(0, 3, 1, 2))
145
+ x = self.net_2(x)
146
+ x = self.net_3(x)
147
+
148
+
149
+ return x
150
+
151
+
152
+ class Block(nn.Module):
153
+ """Transformer blocks with support of window attention and residual propagation blocks"""
154
+
155
+ def __init__(
156
+ self,
157
+ dim: int,
158
+ num_heads: int,
159
+ mlp_ratio: float = 4.0,
160
+ qkv_bias: bool = True,
161
+ norm_layer: Type[nn.Module] = nn.LayerNorm,
162
+ act_layer: Type[nn.Module] = nn.GELU,
163
+ use_rel_pos: bool = False,
164
+ rel_pos_zero_init: bool = True,
165
+ window_size: int = 0,
166
+ input_size: Optional[Tuple[int, int]] = None,
167
+ ) -> None:
168
+ """
169
+ Args:
170
+ dim (int): Number of input channels.
171
+ num_heads (int): Number of attention heads in each ViT block.
172
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
173
+ qkv_bias (bool): If True, add a learnable bias to query, key, value.
174
+ norm_layer (nn.Module): Normalization layer.
175
+ act_layer (nn.Module): Activation layer.
176
+ use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
177
+ rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
178
+ window_size (int): Window size for window attention blocks. If it equals 0, then
179
+ use global attention.
180
+ input_size (tuple(int, int) or None): Input resolution for calculating the relative
181
+ positional parameter size.
182
+ """
183
+ super().__init__()
184
+ self.norm1 = norm_layer(dim)
185
+ self.attn = Attention(
186
+ dim,
187
+ num_heads=num_heads,
188
+ qkv_bias=qkv_bias,
189
+ use_rel_pos=use_rel_pos,
190
+ rel_pos_zero_init=rel_pos_zero_init,
191
+ input_size=input_size if window_size == 0 else (window_size, window_size),
192
+ )
193
+
194
+ self.norm2 = norm_layer(dim)
195
+ self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer)
196
+
197
+ self.window_size = window_size
198
+
199
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
200
+ shortcut = x
201
+ x = self.norm1(x)
202
+ # Window partition
203
+ if self.window_size > 0:
204
+ H, W = x.shape[1], x.shape[2]
205
+ x, pad_hw = window_partition(x, self.window_size)
206
+
207
+ x = self.attn(x)
208
+ # Reverse window partition
209
+ if self.window_size > 0:
210
+ x = window_unpartition(x, self.window_size, pad_hw, (H, W))
211
+
212
+ x = shortcut + x
213
+ x = x + self.mlp(self.norm2(x))
214
+
215
+ return x
216
+
217
+
218
+ class Attention(nn.Module):
219
+ """Multi-head Attention block with relative position embeddings."""
220
+
221
+ def __init__(
222
+ self,
223
+ dim: int,
224
+ num_heads: int = 8,
225
+ qkv_bias: bool = True,
226
+ use_rel_pos: bool = False,
227
+ rel_pos_zero_init: bool = True,
228
+ input_size: Optional[Tuple[int, int]] = None,
229
+ ) -> None:
230
+ """
231
+ Args:
232
+ dim (int): Number of input channels.
233
+ num_heads (int): Number of attention heads.
234
+ qkv_bias (bool): If True, add a learnable bias to query, key, value.
235
+ rel_pos (bool): If True, add relative positional embeddings to the attention map.
236
+ rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
237
+ input_size (tuple(int, int) or None): Input resolution for calculating the relative
238
+ positional parameter size.
239
+ """
240
+ super().__init__()
241
+ self.num_heads = num_heads
242
+ head_dim = dim // num_heads
243
+ self.scale = head_dim**-0.5
244
+
245
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
246
+ self.proj = nn.Linear(dim, dim)
247
+
248
+ self.use_rel_pos = use_rel_pos
249
+ if self.use_rel_pos:
250
+ assert (
251
+ input_size is not None
252
+ ), "Input size must be provided if using relative positional encoding."
253
+ # initialize relative positional embeddings
254
+ self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
255
+ self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))
256
+
257
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
258
+ B, H, W, _ = x.shape
259
+ # qkv with shape (3, B, nHead, H * W, C)
260
+ qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
261
+ # q, k, v with shape (B * nHead, H * W, C)
262
+ q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)
263
+
264
+ attn = (q * self.scale) @ k.transpose(-2, -1)
265
+
266
+ if self.use_rel_pos:
267
+ attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W))
268
+
269
+ attn = attn.softmax(dim=-1)
270
+ x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)
271
+ x = self.proj(x)
272
+
273
+ return x
274
+
275
+
276
+ def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]:
277
+ """
278
+ Partition into non-overlapping windows with padding if needed.
279
+ Args:
280
+ x (tensor): input tokens with [B, H, W, C].
281
+ window_size (int): window size.
282
+
283
+ Returns:
284
+ windows: windows after partition with [B * num_windows, window_size, window_size, C].
285
+ (Hp, Wp): padded height and width before partition
286
+ """
287
+ B, H, W, C = x.shape
288
+
289
+ pad_h = (window_size - H % window_size) % window_size
290
+ pad_w = (window_size - W % window_size) % window_size
291
+ if pad_h > 0 or pad_w > 0:
292
+ x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
293
+ Hp, Wp = H + pad_h, W + pad_w
294
+
295
+ x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
296
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
297
+ return windows, (Hp, Wp)
298
+
299
+
300
+ def window_unpartition(
301
+ windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int]
302
+ ) -> torch.Tensor:
303
+ """
304
+ Window unpartition into original sequences and removing padding.
305
+ Args:
306
+ windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
307
+ window_size (int): window size.
308
+ pad_hw (Tuple): padded height and width (Hp, Wp).
309
+ hw (Tuple): original height and width (H, W) before padding.
310
+
311
+ Returns:
312
+ x: unpartitioned sequences with [B, H, W, C].
313
+ """
314
+ Hp, Wp = pad_hw
315
+ H, W = hw
316
+ B = windows.shape[0] // (Hp * Wp // window_size // window_size)
317
+ x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
318
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
319
+
320
+ if Hp > H or Wp > W:
321
+ x = x[:, :H, :W, :].contiguous()
322
+ return x
323
+
324
+
325
+ def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
326
+ """
327
+ Get relative positional embeddings according to the relative positions of
328
+ query and key sizes.
329
+ Args:
330
+ q_size (int): size of query q.
331
+ k_size (int): size of key k.
332
+ rel_pos (Tensor): relative position embeddings (L, C).
333
+
334
+ Returns:
335
+ Extracted positional embeddings according to relative positions.
336
+ """
337
+ max_rel_dist = int(2 * max(q_size, k_size) - 1)
338
+ # Interpolate rel pos if needed.
339
+ if rel_pos.shape[0] != max_rel_dist:
340
+ # Interpolate rel pos.
341
+ rel_pos_resized = F.interpolate(
342
+ rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
343
+ size=max_rel_dist,
344
+ mode="linear",
345
+ )
346
+ rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
347
+ else:
348
+ rel_pos_resized = rel_pos
349
+
350
+ # Scale the coords with short length if shapes for q and k are different.
351
+ q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
352
+ k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
353
+ relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
354
+
355
+ return rel_pos_resized[relative_coords.long()]
356
+
357
+
358
+ def add_decomposed_rel_pos(
359
+ attn: torch.Tensor,
360
+ q: torch.Tensor,
361
+ rel_pos_h: torch.Tensor,
362
+ rel_pos_w: torch.Tensor,
363
+ q_size: Tuple[int, int],
364
+ k_size: Tuple[int, int],
365
+ ) -> torch.Tensor:
366
+ """
367
+ Args:
368
+ attn (Tensor): attention map.
369
+ q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
370
+ rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
371
+ rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
372
+ q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
373
+ k_size (Tuple): spatial sequence size of key k with (k_h, k_w).
374
+
375
+ Returns:
376
+ attn (Tensor): attention map with added relative positional embeddings.
377
+ """
378
+ q_h, q_w = q_size
379
+ k_h, k_w = k_size
380
+ Rh = get_rel_pos(q_h, k_h, rel_pos_h)
381
+ Rw = get_rel_pos(q_w, k_w, rel_pos_w)
382
+
383
+ B, _, dim = q.shape
384
+ r_q = q.reshape(B, q_h, q_w, dim)
385
+ rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
386
+ rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
387
+
388
+ attn = (
389
+ attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
390
+ ).view(B, q_h * q_w, k_h * k_w)
391
+
392
+ return attn
393
+
394
+
395
+ class PatchEmbed(nn.Module):
396
+ """
397
+ Image to Patch Embedding.
398
+ """
399
+
400
+ def __init__(
401
+ self,
402
+ kernel_size: Tuple[int, int] = (16, 16),
403
+ stride: Tuple[int, int] = (16, 16),
404
+ padding: Tuple[int, int] = (0, 0),
405
+ in_chans: int = 3,
406
+ embed_dim: int = 768,
407
+ ) -> None:
408
+ """
409
+ Args:
410
+ kernel_size (Tuple): kernel size of the projection layer.
411
+ stride (Tuple): stride of the projection layer.
412
+ padding (Tuple): padding size of the projection layer.
413
+ in_chans (int): Number of input image channels.
414
+ embed_dim (int): Patch embedding dimension.
415
+ """
416
+ super().__init__()
417
+
418
+ self.proj = nn.Conv2d(
419
+ in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
420
+ )
421
+
422
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
423
+ x = self.proj(x)
424
+ # B C H W -> B H W C
425
+ x = x.permute(0, 2, 3, 1)
426
+ return x
427
+
428
+
429
+
430
+ def build_GOT_vit_b(checkpoint=None):
431
+ return _build_GOT_vision(
432
+ encoder_embed_dim=768,
433
+ encoder_depth=12,
434
+ encoder_num_heads=12,
435
+ encoder_global_attn_indexes=[2, 5, 8, 11],
436
+ checkpoint=checkpoint,
437
+ )
438
+
439
+
440
+ def _build_GOT_vision(
441
+ encoder_embed_dim,
442
+ encoder_depth,
443
+ encoder_num_heads,
444
+ encoder_global_attn_indexes,
445
+ checkpoint=None,
446
+ ):
447
+ prompt_embed_dim = 256
448
+ image_size = 1024
449
+ vit_patch_size = 16
450
+ image_embedding_size = image_size // vit_patch_size
451
+ image_encoder=ImageEncoderViT(
452
+ depth=encoder_depth,
453
+ embed_dim=encoder_embed_dim,
454
+ img_size=image_size,
455
+ mlp_ratio=4,
456
+ norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
457
+ num_heads=encoder_num_heads,
458
+ patch_size=vit_patch_size,
459
+ qkv_bias=True,
460
+ use_rel_pos=True,
461
+ global_attn_indexes=encoder_global_attn_indexes,
462
+ window_size=14,
463
+ out_chans=prompt_embed_dim,
464
+ )
465
+
466
+
467
+ return image_encoder
468
+
modeling_GOT.py ADDED
@@ -0,0 +1,659 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import Qwen2Config, Qwen2Model, Qwen2ForCausalLM, StoppingCriteria, TextStreamer
2
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
3
+ from typing import List, Optional, Tuple, Union
4
+ from transformers.cache_utils import Cache
5
+ import requests
6
+ from PIL import Image
7
+ from io import BytesIO
8
+ import torch
9
+ import torch.nn as nn
10
+ from torch.nn import CrossEntropyLoss
11
+ from .got_vision_b import build_GOT_vit_b
12
+ from torchvision import transforms
13
+ from torchvision.transforms.functional import InterpolationMode
14
+ import dataclasses
15
+
16
+
17
+ DEFAULT_IMAGE_TOKEN = "<image>"
18
+ DEFAULT_IMAGE_PATCH_TOKEN = '<imgpad>'
19
+ DEFAULT_IM_START_TOKEN = '<img>'
20
+ DEFAULT_IM_END_TOKEN = '</img>'
21
+
22
+ from enum import auto, Enum
23
+ class SeparatorStyle(Enum):
24
+ """Different separator style."""
25
+ SINGLE = auto()
26
+ TWO = auto()
27
+ MPT = auto()
28
+
29
+
30
+ @dataclasses.dataclass
31
+ class Conversation:
32
+ """A class that keeps all conversation history."""
33
+ system: str
34
+ roles: List[str]
35
+ messages: List[List[str]]
36
+ offset: int
37
+ sep_style: SeparatorStyle = SeparatorStyle.SINGLE
38
+ sep: str = "<|im_end|>"
39
+ sep2: str = None
40
+ version: str = "Unknown"
41
+
42
+ skip_next: bool = False
43
+
44
+ def get_prompt(self):
45
+ if self.sep_style == SeparatorStyle.SINGLE:
46
+ ret = self.system + self.sep + '\n'
47
+ for role, message in self.messages:
48
+ if message:
49
+ if type(message) is tuple:
50
+ message, _, _ = message
51
+ ret += role + ": " + message + self.sep
52
+ else:
53
+ ret += role + ":"
54
+ return ret
55
+ elif self.sep_style == SeparatorStyle.TWO:
56
+ seps = [self.sep, self.sep2]
57
+ ret = self.system + seps[0]
58
+ for i, (role, message) in enumerate(self.messages):
59
+ if message:
60
+ if type(message) is tuple:
61
+ message, _, _ = message
62
+ ret += role + ": " + message + seps[i % 2]
63
+ else:
64
+ ret += role + ":"
65
+ return ret
66
+ if self.sep_style == SeparatorStyle.MPT:
67
+ if self.system:
68
+ ret = self.system + self.sep
69
+ else:
70
+ ret = ''
71
+ for role, message in self.messages:
72
+ if message:
73
+ if type(message) is tuple:
74
+ message, _, _ = message
75
+ ret += role + message + self.sep
76
+ else:
77
+ ret += role
78
+ return ret
79
+ else:
80
+ raise ValueError(f"Invalid style: {self.sep_style}")
81
+
82
+
83
+ def append_message(self, role, message):
84
+ self.messages.append([role, message])
85
+
86
+ def copy(self):
87
+ return Conversation(
88
+ system=self.system,
89
+ roles=self.roles,
90
+ messages=[[x, y] for x, y in self.messages],
91
+ offset=self.offset,
92
+ sep_style=self.sep_style,
93
+ sep=self.sep,
94
+ sep2=self.sep2)
95
+
96
+
97
+
98
+ class KeywordsStoppingCriteria(StoppingCriteria):
99
+ def __init__(self, keywords, tokenizer, input_ids):
100
+ self.keywords = keywords
101
+ self.keyword_ids = [tokenizer(keyword).input_ids for keyword in keywords]
102
+ self.keyword_ids = [keyword_id[0] for keyword_id in self.keyword_ids if type(keyword_id) is list and len(keyword_id) == 1]
103
+ self.tokenizer = tokenizer
104
+ self.start_len = None
105
+ self.input_ids = input_ids
106
+
107
+ def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
108
+ if self.start_len is None:
109
+ self.start_len = self.input_ids.shape[1]
110
+ else:
111
+ for keyword_id in self.keyword_ids:
112
+ if output_ids[0, -1] == keyword_id:
113
+ return True
114
+ outputs = self.tokenizer.batch_decode(output_ids[:, self.start_len:], skip_special_tokens=True)[0]
115
+ for keyword in self.keywords:
116
+ if keyword in outputs:
117
+ return True
118
+ return False
119
+
120
+
121
+ class GOTImageEvalProcessor:
122
+ def __init__(self, image_size=384, mean=None, std=None):
123
+ if mean is None:
124
+ mean = (0.48145466, 0.4578275, 0.40821073)
125
+ if std is None:
126
+ std = (0.26862954, 0.26130258, 0.27577711)
127
+
128
+ self.normalize = transforms.Normalize(mean, std)
129
+
130
+ self.transform = transforms.Compose(
131
+ [
132
+ transforms.Resize(
133
+ (image_size, image_size), interpolation=InterpolationMode.BICUBIC
134
+ ),
135
+ transforms.ToTensor(),
136
+ self.normalize,
137
+ ]
138
+ )
139
+ def __call__(self, item):
140
+ return self.transform(item)
141
+
142
+
143
+
144
+ class GOTConfig(Qwen2Config):
145
+ model_type = "GOT"
146
+
147
+
148
+ class GOTQwenModel(Qwen2Model):
149
+ config_class = GOTConfig
150
+
151
+ def __init__(self, config: Qwen2Config):
152
+ super(GOTQwenModel, self).__init__(config)
153
+
154
+ self.vision_tower_high = build_GOT_vit_b()
155
+
156
+ self.mm_projector_vary = nn.Linear(1024, 1024)
157
+
158
+
159
+ def initialize_vision_modules(
160
+ self,
161
+ vision_tower,
162
+ pretrained_stage1_model=None,
163
+ freeze_vision_tower=False,
164
+ use_im_start_end=False,
165
+ vision_select_layer=-1,
166
+ dtype=torch.float16,
167
+ device="cuda"
168
+ ):
169
+
170
+
171
+ image_processor_high = GOTImageEvalProcessor(image_size=1024)
172
+
173
+ self.vision_tower_high = self.vision_tower_high.to(dtype=dtype, device=device)
174
+
175
+ self.mm_projector_vary = self.mm_projector_vary.to(dtype=dtype, device=device)
176
+
177
+
178
+ image_token_len = 256
179
+
180
+ self.config.vision_tower = vision_tower
181
+ self.config.image_token_len = image_token_len
182
+
183
+ self.config.use_im_start_end = True
184
+
185
+ self.config.vision_select_layer = vision_select_layer
186
+ self.config.freeze_vision_tower = freeze_vision_tower
187
+
188
+ return dict(
189
+ image_processor_high=image_processor_high,
190
+ image_token_len=image_token_len,
191
+ )
192
+
193
+
194
+ def forward(
195
+ self,
196
+ input_ids: torch.LongTensor = None,
197
+ attention_mask: Optional[torch.Tensor] = None,
198
+ position_ids: Optional[torch.LongTensor] = None,
199
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
200
+ inputs_embeds: Optional[torch.FloatTensor] = None,
201
+ use_cache: Optional[bool] = None,
202
+ output_attentions: Optional[bool] = None,
203
+ output_hidden_states: Optional[bool] = None,
204
+ images: Optional[torch.FloatTensor] = None,
205
+ return_dict: Optional[bool] = None,
206
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
207
+
208
+ # HACK: replace back original embeddings for LLaVA pretraining
209
+ orig_embeds_params = getattr(self, 'orig_embeds_params', None)
210
+ if orig_embeds_params is not None:
211
+ with torch.no_grad():
212
+ self.get_input_embeddings().weight[:-self.num_new_tokens] = orig_embeds_params[:-self.num_new_tokens].data
213
+
214
+ if inputs_embeds is None:
215
+ inputs_embeds = self.embed_tokens(input_ids)
216
+
217
+
218
+ vision_tower_high = getattr(self, 'vision_tower_high', None)
219
+
220
+
221
+ if vision_tower_high is not None and (input_ids.shape[1] != 1 or self.training) and images is not None:
222
+ use_im_start_end = getattr(self.config, "use_im_start_end", -1)
223
+
224
+ vision_select_layer = getattr(self.config, "vision_select_layer", -1)
225
+ im_patch_token = getattr(self.config, "im_patch_token", -1)
226
+ im_start_token = getattr(self.config, "im_start_token", -1)
227
+ im_end_token = getattr(self.config, "im_end_token", -1)
228
+ freeze_vision_tower = getattr(self.config, "freeze_vision_tower", False)
229
+
230
+ im_patch_token = 151859
231
+
232
+ im_start_token = 151857
233
+
234
+ im_end_token = 151858
235
+
236
+ image_features = []
237
+
238
+
239
+ for image in images:
240
+ P, C, H, W = image.shape
241
+ if P == 1:
242
+ with torch.set_grad_enabled(False):
243
+ cnn_feature = vision_tower_high(image)
244
+ cnn_feature = cnn_feature.flatten(2).permute(0, 2, 1) # 256*1024
245
+ image_feature = self.mm_projector_vary(cnn_feature)
246
+ image_features.append(image_feature)
247
+
248
+ else:
249
+ image_patches = torch.unbind(image)
250
+ image_patches_features = []
251
+ for image_patch in image_patches:
252
+ image_p = torch.stack([image_patch])
253
+ with torch.set_grad_enabled(False):
254
+ cnn_feature_p = vision_tower_high(image_p)
255
+ cnn_feature_p = cnn_feature_p.flatten(2).permute(0, 2, 1)
256
+ image_feature_p = self.mm_projector_vary(cnn_feature_p)
257
+ image_patches_features.append(image_feature_p)
258
+ image_feature = torch.cat(image_patches_features, dim=1)
259
+ image_features.append(image_feature)
260
+
261
+
262
+ dummy_image_features_2 = torch.zeros(256, 1024, device=inputs_embeds.device, dtype=inputs_embeds.dtype)
263
+ dummy_image_features = dummy_image_features_2
264
+ use_im_start_end = True
265
+ new_input_embeds = []
266
+ for cur_input_ids, cur_input_embeds, cur_image_features in zip(input_ids, inputs_embeds, image_features):
267
+ if (cur_input_ids == im_patch_token).sum() == 0:
268
+ cur_input_embeds = cur_input_embeds + (0. * dummy_image_features).sum()
269
+ new_input_embeds.append(cur_input_embeds)
270
+ continue
271
+
272
+ if use_im_start_end:
273
+ if (cur_input_ids == im_start_token).sum() != (cur_input_ids == im_end_token).sum():
274
+ raise ValueError("The number of image start tokens and image end tokens should be the same.")
275
+
276
+ image_start_tokens = torch.where(cur_input_ids == im_start_token)[0]
277
+ for image_start_token_pos, per_cur_image_features in zip(image_start_tokens, cur_image_features):
278
+ per_cur_image_features = per_cur_image_features.to(device=cur_input_embeds.device)
279
+ num_patches = per_cur_image_features.shape[0]
280
+
281
+ if cur_input_ids[image_start_token_pos + num_patches + 1] != im_end_token:
282
+ raise ValueError("The image end token should follow the image start token.")
283
+
284
+ cur_input_embeds = torch.cat(
285
+ (
286
+ cur_input_embeds[:image_start_token_pos+1],
287
+ per_cur_image_features,
288
+ cur_input_embeds[image_start_token_pos + num_patches + 1:]
289
+ ),
290
+ dim=0
291
+ )
292
+
293
+
294
+ new_input_embeds.append(cur_input_embeds)
295
+ else:
296
+ raise NotImplementedError
297
+
298
+ inputs_embeds = torch.stack(new_input_embeds, dim=0)
299
+
300
+ return super(GOTQwenModel, self).forward(
301
+ input_ids=None, attention_mask=attention_mask, past_key_values=past_key_values,
302
+ inputs_embeds=inputs_embeds, use_cache=use_cache, position_ids = position_ids,
303
+ output_attentions=output_attentions, output_hidden_states=output_hidden_states,
304
+ return_dict=return_dict
305
+ )
306
+
307
+
308
+
309
+ class GOTQwenForCausalLM(Qwen2ForCausalLM):
310
+ config_class = GOTConfig
311
+ # supports_gradient_checkpointing = True
312
+
313
+ def __init__(self, config):
314
+ super(Qwen2ForCausalLM, self).__init__(config)
315
+ self.model = GOTQwenModel(config)
316
+
317
+ self.vocab_size = config.vocab_size
318
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
319
+
320
+ # Initialize weights and apply final processing
321
+ self.post_init()
322
+
323
+ def get_model(self):
324
+ return self.model
325
+
326
+ def forward(
327
+ self,
328
+ input_ids: torch.LongTensor = None,
329
+ attention_mask: Optional[torch.Tensor] = None,
330
+ position_ids: Optional[torch.LongTensor] = None,
331
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
332
+ inputs_embeds: Optional[torch.FloatTensor] = None,
333
+ labels: Optional[torch.LongTensor] = None,
334
+ use_cache: Optional[bool] = None,
335
+ output_attentions: Optional[bool] = None,
336
+ output_hidden_states: Optional[bool] = None,
337
+ images: Optional[torch.FloatTensor] = None,
338
+ return_dict: Optional[bool] = None,
339
+
340
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
341
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
342
+ output_hidden_states = (
343
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
344
+ )
345
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
346
+
347
+ outputs = self.model(
348
+ input_ids=input_ids,
349
+ past_key_values=past_key_values,
350
+ attention_mask=attention_mask,
351
+ position_ids=position_ids,
352
+ inputs_embeds=inputs_embeds,
353
+ use_cache=use_cache,
354
+ output_attentions=output_attentions,
355
+ output_hidden_states=output_hidden_states,
356
+ images=images,
357
+ return_dict=return_dict
358
+
359
+ )
360
+
361
+ hidden_states = outputs[0]
362
+ logits = self.lm_head(hidden_states)
363
+ logits = logits.float()
364
+
365
+ # logits
366
+
367
+ loss = None
368
+ if labels is not None:
369
+ # Shift so that tokens < n predict n
370
+ shift_logits = logits[..., :-1, :].contiguous()
371
+ shift_labels = labels[..., 1:].contiguous()
372
+ # Flatten the tokens
373
+ loss_fct = CrossEntropyLoss()
374
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
375
+ shift_labels = shift_labels.view(-1)
376
+ # Enable model parallelism
377
+ shift_labels = shift_labels.to(shift_logits.device)
378
+ loss = loss_fct(shift_logits, shift_labels)
379
+
380
+ if not return_dict:
381
+ output = (logits,) + outputs[1:]
382
+ return (loss,) + output if loss is not None else output
383
+
384
+ return CausalLMOutputWithPast(
385
+ loss=loss,
386
+ logits=logits,
387
+ past_key_values=outputs.past_key_values,
388
+ hidden_states=outputs.hidden_states,
389
+ attentions=outputs.attentions,
390
+ )
391
+
392
+
393
+ def prepare_inputs_for_generation(
394
+ self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
395
+ ):
396
+ # Omit tokens covered by past_key_values
397
+ if past_key_values is not None:
398
+ if isinstance(past_key_values, Cache):
399
+ cache_length = past_key_values.get_seq_length()
400
+ past_length = past_key_values.seen_tokens
401
+ max_cache_length = past_key_values.get_max_length()
402
+ else:
403
+ cache_length = past_length = past_key_values[0][0].shape[2]
404
+ max_cache_length = None
405
+
406
+ # Keep only the unprocessed tokens:
407
+ # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
408
+ # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
409
+ # input)
410
+ if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
411
+ input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
412
+ # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
413
+ # input_ids based on the past_length.
414
+ elif past_length < input_ids.shape[1]:
415
+ input_ids = input_ids[:, past_length:]
416
+ # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
417
+
418
+ # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
419
+ if (
420
+ max_cache_length is not None
421
+ and attention_mask is not None
422
+ and cache_length + input_ids.shape[1] > max_cache_length
423
+ ):
424
+ attention_mask = attention_mask[:, -max_cache_length:]
425
+
426
+ position_ids = kwargs.get("position_ids", None)
427
+ if attention_mask is not None and position_ids is None:
428
+ # create position_ids on the fly for batch generation
429
+ position_ids = attention_mask.long().cumsum(-1) - 1
430
+ position_ids.masked_fill_(attention_mask == 0, 1)
431
+ if past_key_values:
432
+ position_ids = position_ids[:, -input_ids.shape[1] :]
433
+
434
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
435
+ if inputs_embeds is not None and past_key_values is None:
436
+ model_inputs = {"inputs_embeds": inputs_embeds}
437
+ else:
438
+ model_inputs = {"input_ids": input_ids}
439
+
440
+ model_inputs.update(
441
+ {
442
+ "position_ids": position_ids,
443
+ "past_key_values": past_key_values,
444
+ "use_cache": kwargs.get("use_cache"),
445
+ "attention_mask": attention_mask,
446
+ "images": kwargs.get("images", None),
447
+ }
448
+ )
449
+ return model_inputs
450
+
451
+ def initialize_vision_tokenizer(
452
+ self,
453
+ tokenizer,
454
+ freeze_lm_model=False,
455
+ pretrained_stage1_model=None,
456
+ device="cuda"
457
+ ):
458
+ config = self.get_model().config
459
+
460
+
461
+ self.resize_token_embeddings(len(tokenizer))
462
+
463
+ config.im_patch_token = 151859
464
+
465
+ config.use_im_start_end = True
466
+
467
+ if config.use_im_start_end:
468
+ self.resize_token_embeddings(len(tokenizer))
469
+ config.im_start_token, config.im_end_token = 151857, 151858
470
+
471
+ def load_image(self, image_file):
472
+ if image_file.startswith('http') or image_file.startswith('https'):
473
+ response = requests.get(image_file)
474
+ image = Image.open(BytesIO(response.content)).convert('RGB')
475
+ else:
476
+ image = Image.open(image_file).convert('RGB')
477
+ return image
478
+
479
+ def disable_torch_init(self):
480
+ """
481
+ Disable the redundant torch default initialization to accelerate model creation.
482
+ """
483
+ import torch
484
+ setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
485
+ setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
486
+
487
+ def chat(self, tokenizer, image_file, ocr_type, ocr_box='', ocr_color='', render=False):
488
+
489
+ self.disable_torch_init()
490
+
491
+
492
+ image_processor_high = GOTImageEvalProcessor(image_size=1024)
493
+
494
+ use_im_start_end = True
495
+
496
+ image_token_len = 256
497
+
498
+ image = self.load_image(image_file)
499
+
500
+ w, h = image.size
501
+
502
+ if ocr_type == 'format':
503
+ qs = 'OCR with format: '
504
+ else:
505
+ qs = 'OCR: '
506
+
507
+ if ocr_box:
508
+ bbox = eval(ocr_box)
509
+ if len(bbox) == 2:
510
+ bbox[0] = int(bbox[0]/w*1000)
511
+ bbox[1] = int(bbox[1]/h*1000)
512
+ if len(bbox) == 4:
513
+ bbox[0] = int(bbox[0]/w*1000)
514
+ bbox[1] = int(bbox[1]/h*1000)
515
+ bbox[2] = int(bbox[2]/w*1000)
516
+ bbox[3] = int(bbox[3]/h*1000)
517
+ if ocr_type == 'format':
518
+ qs = str(bbox) + ' ' + 'OCR with format: '
519
+ else:
520
+ qs = str(bbox) + ' ' + 'OCR: '
521
+
522
+ if ocr_color:
523
+ if ocr_type == 'format':
524
+ qs = '[' + ocr_color + ']' + ' ' + 'OCR with format: '
525
+ else:
526
+ qs = '[' + ocr_color + ']' + ' ' + 'OCR: '
527
+
528
+ if use_im_start_end:
529
+ qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_PATCH_TOKEN*image_token_len + DEFAULT_IM_END_TOKEN + '\n' + qs
530
+ else:
531
+ qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
532
+
533
+
534
+ conv_mpt = Conversation(
535
+ system="""<|im_start|>system
536
+ You should follow the instructions carefully and explain your answers in detail.""",
537
+ # system = None,
538
+ roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
539
+ version="mpt",
540
+ messages=(),
541
+ offset=0,
542
+ sep_style=SeparatorStyle.MPT,
543
+ sep="<|im_end|>",
544
+ )
545
+
546
+ conv = conv_mpt.copy()
547
+ conv.append_message(conv.roles[0], qs)
548
+ conv.append_message(conv.roles[1], None)
549
+ prompt = conv.get_prompt()
550
+
551
+ print(prompt)
552
+
553
+ inputs = tokenizer([prompt])
554
+
555
+ image_tensor_1 = image_processor_high(image)
556
+
557
+ input_ids = torch.as_tensor(inputs.input_ids).cuda()
558
+
559
+ stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
560
+ keywords = [stop_str]
561
+ stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
562
+ streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
563
+
564
+
565
+ with torch.autocast("cuda", dtype=torch.bfloat16):
566
+ output_ids = self.generate(
567
+ input_ids,
568
+ images=[image_tensor_1.unsqueeze(0).half().cuda()],
569
+ do_sample=False,
570
+ num_beams = 1,
571
+ no_repeat_ngram_size = 20,
572
+ streamer=streamer,
573
+ max_new_tokens=4096,
574
+ stopping_criteria=[stopping_criteria]
575
+ )
576
+
577
+
578
+ # if render:
579
+ # print('==============rendering===============')
580
+
581
+ # outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
582
+
583
+ # if outputs.endswith(stop_str):
584
+ # outputs = outputs[:-len(stop_str)]
585
+ # outputs = outputs.strip()
586
+
587
+ # if '**kern' in outputs:
588
+ # import verovio
589
+ # from cairosvg import svg2png
590
+ # import cv2
591
+ # import numpy as np
592
+ # tk = verovio.toolkit()
593
+ # tk.loadData(outputs)
594
+ # tk.setOptions({"pageWidth": 2100, "footer": 'none',
595
+ # 'barLineWidth': 0.5, 'beamMaxSlope': 15,
596
+ # 'staffLineWidth': 0.2, 'spacingStaff': 6})
597
+ # tk.getPageCount()
598
+ # svg = tk.renderToSVG()
599
+ # svg = svg.replace("overflow=\"inherit\"", "overflow=\"visible\"")
600
+
601
+ # svg_to_html(svg, "./results/demo.html")
602
+
603
+ # if ocr_type == 'format' and '**kern' not in outputs:
604
+
605
+
606
+ # if '\\begin{tikzpicture}' not in outputs:
607
+ # html_path = "./render_tools/" + "/content-mmd-to-html.html"
608
+ # html_path_2 = "./results/demo.html"
609
+ # right_num = outputs.count('\\right')
610
+ # left_num = outputs.count('\left')
611
+
612
+ # if right_num != left_num:
613
+ # outputs = outputs.replace('\left(', '(').replace('\\right)', ')').replace('\left[', '[').replace('\\right]', ']').replace('\left{', '{').replace('\\right}', '}').replace('\left|', '|').replace('\\right|', '|').replace('\left.', '.').replace('\\right.', '.')
614
+
615
+
616
+ # outputs = outputs.replace('"', '``').replace('$', '')
617
+
618
+ # outputs_list = outputs.split('\n')
619
+ # gt= ''
620
+ # for out in outputs_list:
621
+ # gt += '"' + out.replace('\\', '\\\\') + r'\n' + '"' + '+' + '\n'
622
+
623
+ # gt = gt[:-2]
624
+
625
+ # with open(html_path, 'r') as web_f:
626
+ # lines = web_f.read()
627
+ # lines = lines.split("const text =")
628
+ # new_web = lines[0] + 'const text =' + gt + lines[1]
629
+ # else:
630
+ # html_path = "./render_tools/" + "/tikz.html"
631
+ # html_path_2 = "./results/demo.html"
632
+ # outputs = outputs.translate(translation_table)
633
+ # outputs_list = outputs.split('\n')
634
+ # gt= ''
635
+ # for out in outputs_list:
636
+ # if out:
637
+ # if '\\begin{tikzpicture}' not in out and '\\end{tikzpicture}' not in out:
638
+ # while out[-1] == ' ':
639
+ # out = out[:-1]
640
+ # if out is None:
641
+ # break
642
+
643
+ # if out:
644
+ # if out[-1] != ';':
645
+ # gt += out[:-1] + ';\n'
646
+ # else:
647
+ # gt += out + '\n'
648
+ # else:
649
+ # gt += out + '\n'
650
+
651
+
652
+ # with open(html_path, 'r') as web_f:
653
+ # lines = web_f.read()
654
+ # lines = lines.split("const text =")
655
+ # new_web = lines[0] + gt + lines[1]
656
+
657
+ # with open(html_path_2, 'w') as web_f_new:
658
+ # web_f_new.write(new_web)
659
+