HaoZhang534 commited on
Commit
a65550c
1 Parent(s): d260573
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. app.py +288 -0
  2. llava/__init__.py +1 -0
  3. llava/__pycache__/__init__.cpython-310.pyc +0 -0
  4. llava/__pycache__/constants.cpython-310.pyc +0 -0
  5. llava/__pycache__/conversation.cpython-310.pyc +0 -0
  6. llava/__pycache__/mm_utils.cpython-310.pyc +0 -0
  7. llava/__pycache__/utils.cpython-310.pyc +0 -0
  8. llava/constants.py +12 -0
  9. llava/conversation.py +554 -0
  10. llava/eval/evaluate_interleave.py +339 -0
  11. llava/eval/model_vqa.py +240 -0
  12. llava/mm_utils.py +381 -0
  13. llava/model/__init__.py +20 -0
  14. llava/model/__pycache__/__init__.cpython-310.pyc +0 -0
  15. llava/model/__pycache__/builder.cpython-310.pyc +0 -0
  16. llava/model/__pycache__/llava_arch.cpython-310.pyc +0 -0
  17. llava/model/apply_delta.py +47 -0
  18. llava/model/builder.py +250 -0
  19. llava/model/consolidate.py +30 -0
  20. llava/model/language_model/__pycache__/llava_gemma.cpython-310.pyc +0 -0
  21. llava/model/language_model/__pycache__/llava_llama.cpython-310.pyc +0 -0
  22. llava/model/language_model/__pycache__/llava_mistral.cpython-310.pyc +0 -0
  23. llava/model/language_model/__pycache__/llava_mixtral.cpython-310.pyc +0 -0
  24. llava/model/language_model/__pycache__/llava_qwen.cpython-310.pyc +0 -0
  25. llava/model/language_model/llava_gemma.py +122 -0
  26. llava/model/language_model/llava_llama.py +131 -0
  27. llava/model/language_model/llava_mistral.py +127 -0
  28. llava/model/language_model/llava_mixtral.py +122 -0
  29. llava/model/language_model/llava_mpt.py +105 -0
  30. llava/model/language_model/llava_qwen.py +128 -0
  31. llava/model/language_model/llava_qwen_moe.py +128 -0
  32. llava/model/llava_arch.py +389 -0
  33. llava/model/make_delta.py +52 -0
  34. llava/model/multimodal_encoder/__pycache__/builder.cpython-310.pyc +0 -0
  35. llava/model/multimodal_encoder/__pycache__/clip_encoder.cpython-310.pyc +0 -0
  36. llava/model/multimodal_encoder/__pycache__/siglip_encoder.cpython-310.pyc +0 -0
  37. llava/model/multimodal_encoder/builder.py +14 -0
  38. llava/model/multimodal_encoder/clip_encoder.py +114 -0
  39. llava/model/multimodal_encoder/siglip_encoder.py +620 -0
  40. llava/model/multimodal_projector/__pycache__/builder.cpython-310.pyc +0 -0
  41. llava/model/multimodal_projector/__pycache__/pooler_projector.cpython-310.pyc +0 -0
  42. llava/model/multimodal_projector/builder.py +65 -0
  43. llava/model/multimodal_projector/pooler_projector.py +33 -0
  44. llava/model/multimodal_resampler/__pycache__/builder.cpython-310.pyc +0 -0
  45. llava/model/multimodal_resampler/__pycache__/masked_drop.cpython-310.pyc +0 -0
  46. llava/model/multimodal_resampler/__pycache__/perceiver.cpython-310.pyc +0 -0
  47. llava/model/multimodal_resampler/__pycache__/qformer.cpython-310.pyc +0 -0
  48. llava/model/multimodal_resampler/__pycache__/spatial_pool.cpython-310.pyc +0 -0
  49. llava/model/multimodal_resampler/builder.py +34 -0
  50. llava/model/multimodal_resampler/masked_drop.py +80 -0
app.py ADDED
@@ -0,0 +1,288 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ # from .demo_modelpart import InferenceDemo
3
+ import gradio as gr
4
+ import os
5
+ # import time
6
+ import cv2
7
+
8
+
9
+ # import copy
10
+ import torch
11
+ # import random
12
+ import numpy as np
13
+
14
+ from llava import conversation as conversation_lib
15
+ from llava.constants import DEFAULT_IMAGE_TOKEN
16
+
17
+
18
+ from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
19
+ from llava.conversation import conv_templates, SeparatorStyle
20
+ from llava.model.builder import load_pretrained_model
21
+ from llava.utils import disable_torch_init
22
+ from llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
23
+
24
+ from PIL import Image
25
+
26
+ import requests
27
+ from PIL import Image
28
+ from io import BytesIO
29
+ from transformers import TextStreamer
30
+
31
+ class InferenceDemo(object):
32
+ def __init__(self,args,model_path,tokenizer, model, image_processor, context_len) -> None:
33
+ disable_torch_init()
34
+
35
+
36
+ self.tokenizer, self.model, self.image_processor, self.context_len = tokenizer, model, image_processor, context_len
37
+
38
+ if "llama-2" in model_name.lower():
39
+ conv_mode = "llava_llama_2"
40
+ elif "v1" in model_name.lower():
41
+ conv_mode = "llava_v1"
42
+ elif "mpt" in model_name.lower():
43
+ conv_mode = "mpt"
44
+ elif 'qwen' in model_name.lower():
45
+ conv_mode = "qwen_1_5"
46
+ else:
47
+ conv_mode = "llava_v0"
48
+
49
+ if args.conv_mode is not None and conv_mode != args.conv_mode:
50
+ print("[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}".format(conv_mode, args.conv_mode, args.conv_mode))
51
+ else:
52
+ args.conv_mode = conv_mode
53
+ self.conv_mode=conv_mode
54
+ self.conversation = conv_templates[args.conv_mode].copy()
55
+ self.num_frames = args.num_frames
56
+
57
+
58
+
59
+ def is_valid_video_filename(name):
60
+ video_extensions = ['avi', 'mp4', 'mov', 'mkv', 'flv', 'wmv', 'mjpeg']
61
+
62
+ ext = name.split('.')[-1].lower()
63
+
64
+ if ext in video_extensions:
65
+ return True
66
+ else:
67
+ return False
68
+
69
+ def sample_frames(video_file, num_frames) :
70
+ video = cv2.VideoCapture(video_file)
71
+ total_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
72
+ interval = total_frames // num_frames
73
+ frames = []
74
+ for i in range(total_frames):
75
+ ret, frame = video.read()
76
+ pil_img = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
77
+ if not ret:
78
+ continue
79
+ if i % interval == 0:
80
+ frames.append(pil_img)
81
+ video.release()
82
+ return frames
83
+
84
+ def load_image(image_file):
85
+ if image_file.startswith("http") or image_file.startswith("https"):
86
+ response = requests.get(image_file)
87
+ if response.status_code == 200:
88
+ image = Image.open(BytesIO(response.content)).convert("RGB")
89
+ else:
90
+ print('failed to load the image')
91
+ else:
92
+ print('Load image from local file')
93
+ print(image_file)
94
+ image = Image.open(image_file).convert("RGB")
95
+
96
+ return image
97
+
98
+
99
+ def clear_history(history):
100
+
101
+ our_chatbot.conversation = conv_templates[our_chatbot.conv_mode].copy()
102
+
103
+ return None
104
+ def clear_response(history):
105
+ for index_conv in range(1, len(history)):
106
+ # loop until get a text response from our model.
107
+ conv = history[-index_conv]
108
+ if not (conv[0] is None):
109
+ break
110
+ question = history[-index_conv][0]
111
+ history = history[:-index_conv]
112
+ return history, question
113
+
114
+ def print_like_dislike(x: gr.LikeData):
115
+ print(x.index, x.value, x.liked)
116
+
117
+
118
+
119
+ def add_message(history, message):
120
+ # history=[]
121
+ global our_chatbot
122
+ if len(history)==0:
123
+ our_chatbot = InferenceDemo(args,model_path,tokenizer, model, image_processor, context_len)
124
+
125
+ for x in message["files"]:
126
+ history.append(((x,), None))
127
+ if message["text"] is not None:
128
+ history.append((message["text"], None))
129
+ return history, gr.MultimodalTextbox(value=None, interactive=False)
130
+
131
+ def bot(history):
132
+ text=history[-1][0]
133
+ images_this_term=[]
134
+ text_this_term=''
135
+ # import pdb;pdb.set_trace()
136
+ num_new_images = 0
137
+ for i,message in enumerate(history[:-1]):
138
+ if type(message[0]) is tuple:
139
+ images_this_term.append(message[0][0])
140
+ if is_valid_video_filename(message[0][0]):
141
+ num_new_images+=our_chatbot.num_frames
142
+ else:
143
+ num_new_images+=1
144
+ else:
145
+ num_new_images=0
146
+
147
+ # for message in history[-i-1:]:
148
+ # images_this_term.append(message[0][0])
149
+
150
+ assert len(images_this_term)>0, "must have an image"
151
+ # image_files = (args.image_file).split(',')
152
+ # image = [load_image(f) for f in images_this_term if f]
153
+ image_list=[]
154
+ for f in images_this_term:
155
+ if is_valid_video_filename(f):
156
+ image_list+=sample_frames(f, our_chatbot.num_frames)
157
+ else:
158
+ image_list.append(load_image(f))
159
+ image_tensor = [our_chatbot.image_processor.preprocess(f, return_tensors="pt")["pixel_values"][0].half().to(our_chatbot.model.device) for f in image_list]
160
+
161
+ image_tensor = torch.stack(image_tensor)
162
+ image_token = DEFAULT_IMAGE_TOKEN*num_new_images
163
+ # if our_chatbot.model.config.mm_use_im_start_end:
164
+ # inp = DEFAULT_IM_START_TOKEN + image_token + DEFAULT_IM_END_TOKEN + "\n" + inp
165
+ # else:
166
+ inp=text
167
+ inp = image_token+ "\n" + inp
168
+ our_chatbot.conversation.append_message(our_chatbot.conversation.roles[0], inp)
169
+ # image = None
170
+ our_chatbot.conversation.append_message(our_chatbot.conversation.roles[1], None)
171
+ prompt = our_chatbot.conversation.get_prompt()
172
+
173
+ input_ids = tokenizer_image_token(prompt, our_chatbot.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(our_chatbot.model.device)
174
+ stop_str = our_chatbot.conversation.sep if our_chatbot.conversation.sep_style != SeparatorStyle.TWO else our_chatbot.conversation.sep2
175
+ keywords = [stop_str]
176
+ stopping_criteria = KeywordsStoppingCriteria(keywords, our_chatbot.tokenizer, input_ids)
177
+ streamer = TextStreamer(our_chatbot.tokenizer, skip_prompt=True, skip_special_tokens=True)
178
+ # import pdb;pdb.set_trace()
179
+ with torch.inference_mode():
180
+ output_ids = our_chatbot.model.generate(input_ids, images=image_tensor, do_sample=True, temperature=0.2, max_new_tokens=1024, streamer=streamer, use_cache=False, stopping_criteria=[stopping_criteria])
181
+
182
+ outputs = our_chatbot.tokenizer.decode(output_ids[0]).strip()
183
+ if outputs.endswith(stop_str):
184
+ outputs = outputs[:-len(stop_str)]
185
+ our_chatbot.conversation.messages[-1][-1] = outputs
186
+
187
+ history[-1]=[text,outputs]
188
+
189
+ return history
190
+ txt = gr.Textbox(
191
+ scale=4,
192
+ show_label=False,
193
+ placeholder="Enter text and press enter.",
194
+ container=False,
195
+ )
196
+ with gr.Blocks() as demo:
197
+ # Informations
198
+ title_markdown = ("""
199
+ # LLaVA-NeXT Interleave
200
+ [[Blog]](https://llava-vl.github.io/blog/2024-06-16-llava-next-interleave/) [[Code]](https://github.com/LLaVA-VL/LLaVA-NeXT) [[Model]](https://huggingface.co/lmms-lab/llava-next-interleave-7b)
201
+ """)
202
+ tos_markdown = ("""
203
+ ### TODO!. Terms of use
204
+ By using this service, users are required to agree to the following terms:
205
+ The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research.
206
+ Please click the "Flag" button if you get any inappropriate answer! We will collect those to keep improving our moderator.
207
+ For an optimal experience, please use desktop computers for this demo, as mobile devices may compromise its quality.
208
+ """)
209
+ learn_more_markdown = ("""
210
+ ### TODO!. License
211
+ The service is a research preview intended for non-commercial use only, subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMA, [Terms of Use](https://openai.com/policies/terms-of-use) of the data generated by OpenAI, and [Privacy Practices](https://chrome.google.com/webstore/detail/sharegpt-share-your-chatg/daiacboceoaocpibfodeljbdfacokfjb) of ShareGPT. Please contact us if you find any potential violation.
212
+ """)
213
+ models = [
214
+ "LLaVA-Interleave-7B",
215
+ ]
216
+ cur_dir = os.path.dirname(os.path.abspath(__file__))
217
+ gr.Markdown(title_markdown)
218
+
219
+ chatbot = gr.Chatbot(
220
+ [],
221
+ elem_id="chatbot",
222
+ bubble_full_width=False
223
+ )
224
+ chat_input = gr.MultimodalTextbox(interactive=True, file_types=["image","video"], placeholder="Enter message or upload file...", show_label=False)
225
+
226
+
227
+
228
+ with gr.Row():
229
+ upvote_btn = gr.Button(value="👍 Upvote", interactive=True)
230
+ downvote_btn = gr.Button(value="👎 Downvote", interactive=True)
231
+ flag_btn = gr.Button(value="⚠️ Flag", interactive=True)
232
+ #stop_btn = gr.Button(value="⏹️ Stop Generation", interactive=True)
233
+ regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=True)
234
+ clear_btn = gr.Button(value="🗑️ Clear history", interactive=True)
235
+ chat_msg = chat_input.submit(add_message, [chatbot, chat_input], [chatbot, chat_input])
236
+ bot_msg = chat_msg.then(bot, chatbot, chatbot, api_name="bot_response")
237
+ bot_msg.then(lambda: gr.MultimodalTextbox(interactive=True), None, [chat_input])
238
+
239
+ chatbot.like(print_like_dislike, None, None)
240
+ clear_btn.click(fn=clear_history, inputs=[chatbot], outputs=[chatbot], api_name="clear_all")
241
+ with gr.Column():
242
+ gr.Examples(examples=[
243
+ [{"files": [f"{cur_dir}/examples/code1.jpeg",f"{cur_dir}/examples/code2.jpeg"], "text": "Please pay attention to the movement of the object from the first image to the second image, then write a HTML code to show this movement."}],
244
+ [{"files": [f"{cur_dir}/examples/shub.jpg",f"{cur_dir}/examples/shuc.jpg",f"{cur_dir}/examples/shud.jpg"], "text": "what is fun about the images?"}],
245
+ [{"files": [f"{cur_dir}/examples/iphone-15-price-1024x576.jpg",f"{cur_dir}/examples/dynamic-island-1024x576.jpg",f"{cur_dir}/examples/iphone-15-colors-1024x576.jpg",f"{cur_dir}/examples/Iphone-15-Usb-c-charger-1024x576.jpg",f"{cur_dir}/examples/A-17-processors-1024x576.jpg"], "text": "The images are the PPT of iPhone 15 review. can you summarize the main information?"}],
246
+ [{"files": [f"{cur_dir}/examples/fangao3.jpeg",f"{cur_dir}/examples/fangao2.jpeg",f"{cur_dir}/examples/fangao1.jpeg"], "text": "Do you kown who draw these paintings?"}],
247
+ [{"files": [f"{cur_dir}/examples/oprah-winfrey-resume.png",f"{cur_dir}/examples/steve-jobs-resume.jpg"], "text": "Hi, there are two candidates, can you provide a brief description for each of them for me?"}],
248
+ [{"files": [f"{cur_dir}/examples/original_bench.jpeg",f"{cur_dir}/examples/changed_bench.jpeg"], "text": "How to edit image1 to make it look like image2?"}],
249
+ [{"files": [f"{cur_dir}/examples/twitter2.jpeg",f"{cur_dir}/examples/twitter3.jpeg",f"{cur_dir}/examples/twitter4.jpeg"], "text": "Please write a twitter blog post with the images."}],
250
+ # [{"files": [f"{cur_dir}/examples/twitter3.jpeg",f"{cur_dir}/examples/twitter4.jpeg"], "text": "Please write a twitter blog post with the images."}],
251
+ # [{"files": [f"playground/demo/examples/lion1_.mp4",f"playground/demo/examples/lion2_.mp4"], "text": "The input contains two videos, the first half is the first video and the second half is the second video. What is the difference between the two videos?"}],
252
+
253
+
254
+
255
+
256
+ ], inputs=[chat_input], label="Compare images: ")
257
+
258
+ demo.queue()
259
+ if __name__ == "__main__":
260
+ import argparse
261
+ argparser = argparse.ArgumentParser()
262
+ argparser.add_argument("--server_name", default="0.0.0.0", type=str)
263
+ argparser.add_argument("--port", default="6123", type=str)
264
+ argparser.add_argument("--model_path", default="", type=str)
265
+ # argparser.add_argument("--model-path", type=str, default="facebook/opt-350m")
266
+ argparser.add_argument("--model-base", type=str, default=None)
267
+ argparser.add_argument("--num-gpus", type=int, default=1)
268
+ argparser.add_argument("--conv-mode", type=str, default=None)
269
+ argparser.add_argument("--temperature", type=float, default=0.2)
270
+ argparser.add_argument("--max-new-tokens", type=int, default=512)
271
+ argparser.add_argument("--num_frames", type=int, default=16)
272
+ argparser.add_argument("--load-8bit", action="store_true")
273
+ argparser.add_argument("--load-4bit", action="store_true")
274
+ argparser.add_argument("--debug", action="store_true")
275
+
276
+ args = argparser.parse_args()
277
+ model_path = args.model_path
278
+ filt_invalid="cut"
279
+ model_name = get_model_name_from_path(args.model_path)
280
+ tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit)
281
+ our_chatbot = None
282
+ # import pdb;pdb.set_trace()
283
+ try:
284
+ demo.launch(server_name=args.server_name, server_port=int(args.port),share=True)
285
+ except Exception as e:
286
+ args.port=int(args.port)+1
287
+ print(f"Port {args.port} is occupied, try port {args.port}")
288
+ demo.launch(server_name=args.server_name, server_port=int(args.port),share=True)
llava/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .model import LlavaLlamaForCausalLM
llava/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (212 Bytes). View file
 
llava/__pycache__/constants.cpython-310.pyc ADDED
Binary file (474 Bytes). View file
 
llava/__pycache__/conversation.cpython-310.pyc ADDED
Binary file (13.3 kB). View file
 
llava/__pycache__/mm_utils.cpython-310.pyc ADDED
Binary file (13.5 kB). View file
 
llava/__pycache__/utils.cpython-310.pyc ADDED
Binary file (4.3 kB). View file
 
llava/constants.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ CONTROLLER_HEART_BEAT_EXPIRATION = 30
2
+ WORKER_HEART_BEAT_INTERVAL = 15
3
+
4
+ LOGDIR = "."
5
+
6
+ # Model Constants
7
+ IGNORE_INDEX = -100
8
+ IMAGE_TOKEN_INDEX = -200
9
+ DEFAULT_IMAGE_TOKEN = "<image>"
10
+ DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
11
+ DEFAULT_IM_START_TOKEN = "<im_start>"
12
+ DEFAULT_IM_END_TOKEN = "<im_end>"
llava/conversation.py ADDED
@@ -0,0 +1,554 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import dataclasses
2
+ from enum import auto, Enum
3
+ from typing import List, Any, Dict, Union, Tuple
4
+ import re
5
+ import base64
6
+ from io import BytesIO
7
+ from PIL import Image
8
+ from transformers import AutoTokenizer
9
+
10
+
11
+ class SeparatorStyle(Enum):
12
+ """Different separator style."""
13
+
14
+ SINGLE = auto()
15
+ TWO = auto()
16
+ MPT = auto()
17
+ PLAIN = auto()
18
+ CHATML = auto()
19
+ LLAMA_2 = auto()
20
+ LLAMA_3 = auto()
21
+ QWEN = auto()
22
+ GEMMA = auto()
23
+
24
+
25
+ @dataclasses.dataclass
26
+ class Conversation:
27
+ """A class that keeps all conversation history."""
28
+
29
+ system: str
30
+ roles: List[str]
31
+ messages: List[List[str]]
32
+ offset: int
33
+ sep_style: SeparatorStyle = SeparatorStyle.SINGLE
34
+ sep: str = "###"
35
+ sep2: str = None
36
+ version: str = "Unknown"
37
+
38
+ tokenizer_id: str = ""
39
+ tokenizer: Any = None
40
+ # Stop criteria (the default one is EOS token)
41
+ stop_str: Union[str, List[str]] = None
42
+ # Stops generation if meeting any token in this list
43
+ stop_token_ids: List[int] = None
44
+
45
+ skip_next: bool = False
46
+
47
+ def get_prompt(self):
48
+ messages = self.messages
49
+ if len(messages) > 0 and type(messages[0][1]) is tuple:
50
+ messages = self.messages.copy()
51
+ init_role, init_msg = messages[0].copy()
52
+ init_msg = init_msg[0]
53
+ if "mmtag" in self.version:
54
+ init_msg = init_msg.replace("<image>", "").strip()
55
+ messages[0] = (init_role, init_msg)
56
+ messages.insert(0, (self.roles[0], "<Image><image></Image>"))
57
+ messages.insert(1, (self.roles[1], "Received."))
58
+ elif not init_msg.startswith("<image>"):
59
+ init_msg = init_msg.replace("<image>", "").strip()
60
+ messages[0] = (init_role, "<image>\n" + init_msg)
61
+ else:
62
+ messages[0] = (init_role, init_msg)
63
+
64
+ if self.sep_style == SeparatorStyle.SINGLE:
65
+ ret = self.system + self.sep
66
+ for role, message in messages:
67
+ if message:
68
+ if type(message) is tuple:
69
+ message, _, _ = message
70
+ ret += role + ": " + message + self.sep
71
+ else:
72
+ ret += role + ":"
73
+
74
+ elif self.sep_style == SeparatorStyle.TWO:
75
+ seps = [self.sep, self.sep2]
76
+ ret = self.system + seps[0]
77
+ for i, (role, message) in enumerate(messages):
78
+ if message:
79
+ if type(message) is tuple:
80
+ message, _, _ = message
81
+ ret += role + ": " + message + seps[i % 2]
82
+ else:
83
+ ret += role + ":"
84
+
85
+ elif self.sep_style == SeparatorStyle.CHATML:
86
+ ret = "" if self.system == "" else self.system + self.sep + "\n"
87
+ for role, message in messages:
88
+ if message:
89
+ if type(message) is tuple:
90
+ message, images = message
91
+ message = "<image>" * len(images) + message
92
+ ret += role + "\n" + message + self.sep + "\n"
93
+ else:
94
+ ret += role + "\n"
95
+ return ret
96
+
97
+ elif self.sep_style == SeparatorStyle.LLAMA_3:
98
+ chat_template_messages = [{"role": "system", "content": self.system}]
99
+ for role, message in messages:
100
+ if message:
101
+ if type(message) is tuple:
102
+ message, images = message
103
+ message = "<image>" * len(images) + message
104
+ chat_template_messages.append({"role": role, "content": message})
105
+
106
+ # print(chat_template_messages)
107
+ return self.tokenizer.apply_chat_template(chat_template_messages, tokenize=False, add_generation_prompt=True)
108
+ # ret = "" if self.system == "" else self.system + self.sep + "\n"
109
+ # for role, message in messages:
110
+ # if message:
111
+ # if type(message) is tuple:
112
+ # message, images = message
113
+ # message = "<image>" * len(images) + message
114
+ # ret += role + "\n" + message + self.sep + "\n"
115
+ # else:
116
+ # ret += role + "\n"
117
+ # return ret
118
+
119
+ elif self.sep_style == SeparatorStyle.MPT:
120
+ ret = self.system + self.sep
121
+ for role, message in messages:
122
+ if message:
123
+ if type(message) is tuple:
124
+ message, _, _ = message
125
+ ret += role + message + self.sep
126
+ else:
127
+ ret += role
128
+
129
+ elif self.sep_style == SeparatorStyle.GEMMA:
130
+ ret = ""
131
+ for i, (role, message) in enumerate(messages):
132
+ assert role == self.roles[i % 2], "Conversation should alternate user/assistant/user/assistant/..."
133
+ if message:
134
+ if type(message) is tuple:
135
+ message, _, _ = message
136
+ ret += role + message + self.sep
137
+ else:
138
+ ret += role
139
+
140
+ elif self.sep_style == SeparatorStyle.LLAMA_2:
141
+ wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n" if len(msg) > 0 else msg
142
+ wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
143
+ ret = ""
144
+
145
+ for i, (role, message) in enumerate(messages):
146
+ if i == 0:
147
+ assert message, "first message should not be none"
148
+ assert role == self.roles[0], "first message should come from user"
149
+ if message:
150
+ if type(message) is tuple:
151
+ message, _, _ = message
152
+ if i == 0:
153
+ message = wrap_sys(self.system) + message
154
+ if i % 2 == 0:
155
+ message = wrap_inst(message)
156
+ ret += self.sep + message
157
+ else:
158
+ ret += " " + message + " " + self.sep2
159
+ else:
160
+ ret += ""
161
+ ret = ret.lstrip(self.sep)
162
+
163
+ elif self.sep_style == SeparatorStyle.PLAIN:
164
+ seps = [self.sep, self.sep2]
165
+ ret = self.system
166
+ for i, (role, message) in enumerate(messages):
167
+ if message:
168
+ if type(message) is tuple:
169
+ message, _, _ = message
170
+ ret += message + seps[i % 2]
171
+ else:
172
+ ret += ""
173
+ else:
174
+ raise ValueError(f"Invalid style: {self.sep_style}")
175
+
176
+ return ret
177
+
178
+ def append_message(self, role, message):
179
+ self.messages.append([role, message])
180
+
181
+ def process_image(self, image, image_process_mode, return_pil=False, image_format="PNG"):
182
+ if image_process_mode == "Pad":
183
+
184
+ def expand2square(pil_img, background_color=(122, 116, 104)):
185
+ width, height = pil_img.size
186
+ if width == height:
187
+ return pil_img
188
+ elif width > height:
189
+ result = Image.new(pil_img.mode, (width, width), background_color)
190
+ result.paste(pil_img, (0, (width - height) // 2))
191
+ return result
192
+ else:
193
+ result = Image.new(pil_img.mode, (height, height), background_color)
194
+ result.paste(pil_img, ((height - width) // 2, 0))
195
+ return result
196
+
197
+ image = expand2square(image)
198
+ elif image_process_mode in ["Default", "Crop"]:
199
+ pass
200
+ elif image_process_mode == "Resize":
201
+ image = image.resize((336, 336))
202
+ else:
203
+ raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
204
+
205
+ max_hw, min_hw = max(image.size), min(image.size)
206
+ aspect_ratio = max_hw / min_hw
207
+ max_len, min_len = 672, 448
208
+ shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
209
+ longest_edge = int(shortest_edge * aspect_ratio)
210
+ W, H = image.size
211
+ if H > W:
212
+ H, W = longest_edge, shortest_edge
213
+ else:
214
+ H, W = shortest_edge, longest_edge
215
+ image = image.resize((W, H))
216
+ if return_pil:
217
+ return image
218
+ else:
219
+ buffered = BytesIO()
220
+ image.save(buffered, format=image_format)
221
+ img_b64_str = base64.b64encode(buffered.getvalue()).decode()
222
+ return img_b64_str
223
+
224
+ def get_images(self, return_pil=False):
225
+ images = []
226
+ for i, (role, msg) in enumerate(self.messages[self.offset :]):
227
+ if i % 2 == 0:
228
+ if type(msg) is tuple:
229
+ msg, image, image_process_mode = msg
230
+ if type(image) != list:
231
+ image = [image]
232
+ for img in image:
233
+ img = self.process_image(img, image_process_mode, return_pil=return_pil)
234
+ images.append(img)
235
+ return images
236
+
237
+ def to_gradio_chatbot(self):
238
+ ret = []
239
+ for i, (role, msg) in enumerate(self.messages[self.offset :]):
240
+ if i % 2 == 0:
241
+ if type(msg) is tuple:
242
+ msg, image, image_process_mode = msg
243
+ if type(image) != list:
244
+ image = [image]
245
+ if len(image) == 1:
246
+ msg = "<image>\n" + msg.replace("<image>", "").strip()
247
+ else:
248
+ msg = re.sub(r"(<image>)\n(?=<image>)", r"\1 ", msg)
249
+ for img in image:
250
+ img_b64_str = self.process_image(img, "Default", return_pil=False, image_format="JPEG")
251
+ img_str = f'<img src="data:image/jpeg;base64,{img_b64_str}"/>'
252
+ msg = msg.replace("<image>", img_str, 1).strip()
253
+ if len(msg) > 0:
254
+ ret.append([msg, None])
255
+ else:
256
+ ret.append([msg, None])
257
+ else:
258
+ ret[-1][-1] = msg
259
+ return ret
260
+
261
+ def copy(self):
262
+ return Conversation(system=self.system, roles=self.roles, messages=[[x, y] for x, y in self.messages], offset=self.offset, sep_style=self.sep_style, sep=self.sep, sep2=self.sep2, version=self.version)
263
+
264
+ def dict(self):
265
+ if len(self.get_images()) > 0:
266
+ return {
267
+ "system": self.system,
268
+ "roles": self.roles,
269
+ "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
270
+ "offset": self.offset,
271
+ "sep": self.sep,
272
+ "sep2": self.sep2,
273
+ }
274
+ return {
275
+ "system": self.system,
276
+ "roles": self.roles,
277
+ "messages": self.messages,
278
+ "offset": self.offset,
279
+ "sep": self.sep,
280
+ "sep2": self.sep2,
281
+ }
282
+
283
+
284
+ conv_vicuna_v0 = Conversation(
285
+ system="A chat between a curious human and an artificial intelligence assistant. " "The assistant gives helpful, detailed, and polite answers to the human's questions.",
286
+ roles=("Human", "Assistant"),
287
+ messages=[
288
+ ["Human", "What are the key differences between renewable and non-renewable energy sources?"],
289
+ [
290
+ "Assistant",
291
+ "Renewable energy sources are those that can be replenished naturally in a relatively "
292
+ "short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
293
+ "Non-renewable energy sources, on the other hand, are finite and will eventually be "
294
+ "depleted, such as coal, oil, and natural gas. Here are some key differences between "
295
+ "renewable and non-renewable energy sources:\n"
296
+ "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
297
+ "energy sources are finite and will eventually run out.\n"
298
+ "2. Environmental impact: Renewable energy sources have a much lower environmental impact "
299
+ "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
300
+ "and other negative effects.\n"
301
+ "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
302
+ "have lower operational costs than non-renewable sources.\n"
303
+ "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
304
+ "locations than non-renewable sources.\n"
305
+ "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
306
+ "situations and needs, while non-renewable sources are more rigid and inflexible.\n"
307
+ "6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
308
+ "non-renewable sources are not, and their depletion can lead to economic and social instability.\n",
309
+ ],
310
+ ],
311
+ offset=2,
312
+ sep_style=SeparatorStyle.SINGLE,
313
+ sep="###",
314
+ )
315
+
316
+ conv_vicuna_v1 = Conversation(
317
+ system="A chat between a curious user and an artificial intelligence assistant. " "The assistant gives helpful, detailed, and polite answers to the user's questions.",
318
+ roles=("USER", "ASSISTANT"),
319
+ version="v1",
320
+ messages=[],
321
+ offset=0,
322
+ sep_style=SeparatorStyle.TWO,
323
+ sep=" ",
324
+ sep2="</s>",
325
+ )
326
+
327
+ conv_llama_2 = Conversation(
328
+ system="""You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
329
+
330
+ If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.""",
331
+ roles=("USER", "ASSISTANT"),
332
+ version="llama_v2",
333
+ messages=[],
334
+ offset=0,
335
+ sep_style=SeparatorStyle.LLAMA_2,
336
+ sep="<s>",
337
+ sep2="</s>",
338
+ )
339
+
340
+ conv_llava_llama_2 = Conversation(
341
+ system="You are a helpful language and vision assistant. " "You are able to understand the visual content that the user provides, " "and assist the user with a variety of tasks using natural language.",
342
+ roles=("USER", "ASSISTANT"),
343
+ version="llama_v2",
344
+ messages=[],
345
+ offset=0,
346
+ sep_style=SeparatorStyle.LLAMA_2,
347
+ sep="<s>",
348
+ sep2="</s>",
349
+ )
350
+
351
+ try:
352
+ llama3_tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
353
+ except Exception as e:
354
+ print("Error loading llama3 tokenizer")
355
+ print(e)
356
+
357
+ # conv_llava_llama_3 = Conversation(
358
+ # system="You are a helpful language and vision assistant. " "You are able to understand the visual content that the user provides, " "and assist the user with a variety of tasks using natural language.",
359
+ # roles=("<|start_header_id|>user", "<|start_header_id|>assistant"),
360
+ # version="llama_v3",
361
+ # messages=[],
362
+ # offset=0,
363
+ # sep_style=SeparatorStyle.LLAMA_3,
364
+ # tokenizer_id="meta-llama/Meta-Llama-3-8B-Instruct",
365
+ # tokenizer=llama3_tokenizer,
366
+ # stop_token_ids=[128009],
367
+ # )
368
+
369
+ conv_mistral_instruct = Conversation(
370
+ system="",
371
+ roles=("USER", "ASSISTANT"),
372
+ version="llama_v2",
373
+ messages=[],
374
+ offset=0,
375
+ sep_style=SeparatorStyle.LLAMA_2,
376
+ sep="",
377
+ sep2="</s>",
378
+ )
379
+
380
+ conv_llava_llama_2_simple = Conversation(
381
+ system="Answer the questions about the visual content that the user provides.",
382
+ roles=("USER", "ASSISTANT"),
383
+ version="llama_v2",
384
+ messages=[],
385
+ offset=0,
386
+ sep_style=SeparatorStyle.LLAMA_2,
387
+ sep="<s>",
388
+ sep2="</s>",
389
+ )
390
+
391
+ conv_llava_llama_2_mmtag = Conversation(
392
+ system="Answer the questions about the visual content that the user provides." "The visual content will be provided with the following format: <Image>visual content</Image>.",
393
+ roles=("USER", "ASSISTANT"),
394
+ version="llama_v2_mmtag",
395
+ messages=[],
396
+ offset=0,
397
+ sep_style=SeparatorStyle.LLAMA_2,
398
+ sep="<s>",
399
+ sep2="</s>",
400
+ )
401
+
402
+ conv_mpt = Conversation(
403
+ system="""<|im_start|>system
404
+ A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.""",
405
+ roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
406
+ version="mpt",
407
+ messages=[],
408
+ offset=0,
409
+ sep_style=SeparatorStyle.MPT,
410
+ sep="<|im_end|>",
411
+ )
412
+
413
+ conv_qwen = Conversation(
414
+ system="""<|im_start|>system
415
+ You are a helpful assistant.""",
416
+ roles=("<|im_start|>user", "<|im_start|>assistant"),
417
+ version="qwen",
418
+ messages=[],
419
+ offset=0,
420
+ sep_style=SeparatorStyle.CHATML,
421
+ sep="<|im_end|>",
422
+ )
423
+
424
+ conv_gemma_instruct = Conversation(system="", roles=("<start_of_turn>user\n", "<start_of_turn>model\n"), version="gemma", messages=[], offset=0, sep_style=SeparatorStyle.GEMMA, sep="<end_of_turn>\n")
425
+
426
+ conv_llava_plain = Conversation(
427
+ system="",
428
+ roles=("", ""),
429
+ messages=[],
430
+ offset=0,
431
+ sep_style=SeparatorStyle.PLAIN,
432
+ sep="\n",
433
+ )
434
+
435
+ conv_llava_v0 = Conversation(
436
+ system="A chat between a curious human and an artificial intelligence assistant. " "The assistant gives helpful, detailed, and polite answers to the human's questions.",
437
+ roles=("Human", "Assistant"),
438
+ messages=[],
439
+ offset=0,
440
+ sep_style=SeparatorStyle.SINGLE,
441
+ sep="###",
442
+ )
443
+
444
+ conv_llava_v0_mmtag = Conversation(
445
+ system="A chat between a curious user and an artificial intelligence assistant. "
446
+ "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
447
+ "The visual content will be provided with the following format: <Image>visual content</Image>.",
448
+ roles=("Human", "Assistant"),
449
+ messages=[],
450
+ offset=0,
451
+ sep_style=SeparatorStyle.SINGLE,
452
+ sep="###",
453
+ version="v0_mmtag",
454
+ )
455
+
456
+ conv_llava_v1 = Conversation(
457
+ system="A chat between a curious human and an artificial intelligence assistant. " "The assistant gives helpful, detailed, and polite answers to the human's questions.",
458
+ roles=("USER", "ASSISTANT"),
459
+ version="v1",
460
+ messages=[],
461
+ offset=0,
462
+ sep_style=SeparatorStyle.TWO,
463
+ sep=" ",
464
+ sep2="</s>",
465
+ )
466
+
467
+ conv_llava_v1_mmtag = Conversation(
468
+ system="A chat between a curious user and an artificial intelligence assistant. "
469
+ "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
470
+ "The visual content will be provided with the following format: <Image>visual content</Image>.",
471
+ roles=("USER", "ASSISTANT"),
472
+ messages=[],
473
+ offset=0,
474
+ sep_style=SeparatorStyle.TWO,
475
+ sep=" ",
476
+ sep2="</s>",
477
+ version="v1_mmtag",
478
+ )
479
+
480
+ conv_mistral_orca = Conversation(
481
+ system="""<|im_start|>system
482
+ You are MistralOrca, a large language model trained by Alignment Lab AI. Write out your reasoning step-by-step to be sure you get the right answers!""",
483
+ roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
484
+ version="mpt",
485
+ messages=[],
486
+ offset=0,
487
+ sep_style=SeparatorStyle.MPT,
488
+ sep="<|im_end|>",
489
+ )
490
+
491
+ conv_mistral_zephyr = Conversation(
492
+ system="""<|system|>
493
+ You are a helpful AI assistant.""",
494
+ roles=("<|user|>\n", "<|assistant|>\n"),
495
+ version="mpt",
496
+ messages=[],
497
+ offset=0,
498
+ sep_style=SeparatorStyle.MPT,
499
+ sep="</s>",
500
+ )
501
+
502
+ conv_mistral_direct = Conversation(
503
+ system="""<|im_start|>system
504
+ Answer the questions.""",
505
+ roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
506
+ version="mpt",
507
+ messages=[],
508
+ offset=0,
509
+ sep_style=SeparatorStyle.MPT,
510
+ sep="<|im_end|>",
511
+ )
512
+
513
+ conv_chatml_direct = Conversation(
514
+ system="""<|im_start|>system
515
+ Answer the questions.""",
516
+ roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
517
+ version="mpt",
518
+ messages=[],
519
+ offset=0,
520
+ sep_style=SeparatorStyle.MPT,
521
+ sep="<|im_end|>",
522
+ )
523
+
524
+ default_conversation = conv_vicuna_v0
525
+ conv_templates = {
526
+ "default": conv_vicuna_v0,
527
+ "v0": conv_vicuna_v0,
528
+ "v1": conv_vicuna_v1,
529
+ "vicuna_v1": conv_vicuna_v1,
530
+ "llama_2": conv_llama_2,
531
+ "mistral_instruct": conv_mistral_instruct,
532
+ "mistral_orca": conv_mistral_orca,
533
+ "mistral_zephyr": conv_mistral_zephyr,
534
+ "mistral_direct": conv_mistral_direct,
535
+ "plain": conv_llava_plain,
536
+ "v0_plain": conv_llava_plain,
537
+ "chatml_direct": conv_chatml_direct,
538
+ "llava_v0": conv_llava_v0,
539
+ "llava_v0_mmtag": conv_llava_v0_mmtag,
540
+ "llava_v1": conv_llava_v1,
541
+ "llava_v1_mmtag": conv_llava_v1_mmtag,
542
+ "llava_llama_2": conv_llava_llama_2,
543
+ # "llava_llama_3": conv_llava_llama_3,
544
+ "llava_llama_2_simple": conv_llava_llama_2_simple,
545
+ "llava_llama_2_mmtag": conv_llava_llama_2_mmtag,
546
+ "llava_mistral_instruct": conv_mistral_instruct,
547
+ "mpt": conv_mpt,
548
+ "qwen_1_5": conv_qwen,
549
+ "gemma_instruct": conv_gemma_instruct,
550
+ }
551
+
552
+
553
+ if __name__ == "__main__":
554
+ print(default_conversation.get_prompt())
llava/eval/evaluate_interleave.py ADDED
@@ -0,0 +1,339 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ from rouge import Rouge
3
+ import argparse
4
+ import os
5
+ import json
6
+ import numpy as np
7
+ from sklearn.feature_extraction.text import TfidfVectorizer
8
+ from sklearn.metrics.pairwise import cosine_similarity
9
+
10
+
11
+ spot_the_diff = ["Spot-the-Diff", "Birds-to-Words", "CLEVR-Change"]
12
+ image_edit_instruct = ["IEdit", "HQ-Edit", "MagicBrush"]
13
+ visual_story_telling = ["AESOP", "FlintstonesSV", "PororoSV", "VIST"]
14
+ visual_cloze = ["COMICS_Dialogue", "RecipeQA_VisualCloze"]
15
+ text_rich_vqa = ["WebQA", "TQA", "OCR-VQA", "DocVQA"]
16
+ multi_image_vqa = ["MIT-States_StateCoherence", "MIT-States_PropertyCoherence", "VISION", "RecipeQA_ImageCoherence"]
17
+
18
+ puzzle = ["RAVEN"]
19
+ nlrv2 = ["NLVR2_Mantis"]
20
+ qbench = ["QBench"]
21
+
22
+ class Eval:
23
+ def __init__(self):
24
+ self.periodStrip = re.compile("(?!<=\d)(\.)(?!\d)")
25
+ self.commaStrip = re.compile("(\d)(\,)(\d)")
26
+ self.punct = [
27
+ ";",
28
+ r"/",
29
+ "[",
30
+ "]",
31
+ '"',
32
+ "{",
33
+ "}",
34
+ "(",
35
+ ")",
36
+ "=",
37
+ "+",
38
+ "\\",
39
+ "_",
40
+ "-",
41
+ ">",
42
+ "<",
43
+ "@",
44
+ "`",
45
+ ",",
46
+ "?",
47
+ "!",
48
+ ]
49
+
50
+ def processPunctuation(self, inText):
51
+ outText = inText
52
+ for p in self.punct:
53
+ if (p + " " in inText or " " + p in inText) or (
54
+ re.search(self.commaStrip, inText) != None
55
+ ):
56
+ outText = outText.replace(p, "")
57
+ else:
58
+ outText = outText.replace(p, " ")
59
+ outText = self.periodStrip.sub("", outText, re.UNICODE)
60
+ return outText
61
+
62
+ def process(self, answer):
63
+ answer = answer.replace("\n", " ")
64
+ answer = answer.replace("\t", " ")
65
+ answer = answer.strip()
66
+ answer = self.processPunctuation(answer)
67
+ answer = answer.strip('\'')
68
+ answer = answer.strip('\"')
69
+ answer = answer.strip(')')
70
+ answer = answer.strip('(')
71
+ answer = answer.strip().lower()
72
+ return answer
73
+
74
+ def evaluate_rouge(self,preds):
75
+ rouge = Rouge()
76
+ acc = {'f': []}
77
+ eval_list = []
78
+ for i, res in enumerate(preds):
79
+ sample_id = res['sample_id']
80
+ # print(sample_id)
81
+ gt_ans = self.process(res["gt_response"])
82
+ pred_ans = self.process(res["pred_response"])
83
+ # assert gt_ans != ''
84
+
85
+ if gt_ans == '':
86
+ continue
87
+
88
+ if pred_ans == '':
89
+ s = 0
90
+ else:
91
+ if len(pred_ans) > 512:
92
+ pred_ans = pred_ans[0: 512]
93
+ s = rouge.get_scores(pred_ans, gt_ans)[0]['rouge-l']['f']
94
+ acc['f'].append(s)
95
+ eval_list.append({'id':str(sample_id),'score':str(round(s,3))})
96
+ results = {'Rouge-L f': np.mean(acc['f'])}
97
+ return results,eval_list
98
+
99
+
100
+ def judge_multi_choice(self,sample):
101
+ sample_id = sample['sample_id']
102
+ gt_ans = sample["gt_response"]
103
+ pred_ans = sample["pred_response"]
104
+
105
+ if ":" in pred_ans:
106
+ a_list = pred_ans.split(":")
107
+ a_list = [a.strip() for a in a_list ]
108
+ for a in a_list:
109
+ if len(a) == 1 and a[-1] in ["a", "b", "c", "d", "e", "f", "g", "h"]:
110
+ pred_ans = a
111
+
112
+ if pred_ans == gt_ans:
113
+ return 1
114
+ else:
115
+ return 0
116
+
117
+ def process_sample(self,sample):
118
+ sample["gt_response"] = self.process(sample["gt_response"])
119
+ sample["pred_response"] = self.process(sample["pred_response"])
120
+
121
+ def evaluate_multichoice(self, preditions):
122
+ correct = 0
123
+ eval_list = []
124
+ for i, sample in enumerate(preditions):
125
+ self.process_sample(sample)
126
+ score = self.judge_multi_choice(sample)
127
+ sample_id = sample['sample_id']
128
+ sample['result'] = score
129
+ eval_list.append({'id':str(sample_id),'score':str(score)})
130
+ correct+=score
131
+ return {'Accuracy':correct/len(preditions)},eval_list
132
+
133
+ def evaluate_multi_choice_image(self,preditions):
134
+ correct = 0
135
+ eval_list = []
136
+ for i,sample in enumerate(preditions):
137
+ gt_ans = self.process(sample["gt_response"])
138
+ pred_ans = self.process(sample["pred_response"])
139
+ sample_id = sample['sample_id']
140
+
141
+ if ":" in pred_ans:
142
+ a_list = pred_ans.split(":")
143
+ a_list = [a.strip() for a in a_list ]
144
+ for a in a_list:
145
+ if len(a) == 1 and a[-1] in ["a", "b", "c", "d", "e", "f", "g", "h"]:
146
+ pred_ans = a
147
+
148
+ if gt_ans == pred_ans:
149
+ score = 1
150
+ else:
151
+ score = 0
152
+ sample_id = sample['sample_id']
153
+ sample['result'] = score
154
+ eval_list.append({'id':str(sample_id),'score':str(score)})
155
+ correct+=score
156
+ return {'Accuracy':correct/len(preditions)},eval_list
157
+
158
+
159
+ if __name__ == "__main__":
160
+ parser = argparse.ArgumentParser()
161
+ parser.add_argument('--result-dir', type=str, required=True)
162
+
163
+ args = parser.parse_args()
164
+
165
+ result_file = os.path.join(args.result_dir, "result.jsonl")
166
+
167
+ if not os.path.exists(result_file):
168
+ print('No prediction file found')
169
+ exit(0)
170
+ with open(result_file, 'r') as f:
171
+ preds_all = [json.loads(line) for line in f]
172
+
173
+ preds_all_dict = dict()
174
+ for pred in preds_all:
175
+ if pred["dataset"] not in preds_all_dict:
176
+ preds_all_dict[pred["dataset"]] = list()
177
+ preds_all_dict[pred["dataset"]].append(pred)
178
+
179
+ image_choice_dataset_list = ["recipeqa-RecipeQA_VisualCloze", "RecipeQA_ImageCoherence", "COMICS_Panel"]
180
+ E = Eval()
181
+
182
+ eval_result_list = dict()
183
+ eval_result_list_detail = dict()
184
+
185
+ for dataset in preds_all_dict:
186
+
187
+ preds = preds_all_dict[dataset]
188
+ question_type = preds[0]["question_type"]
189
+
190
+ if question_type == 'open-ended':
191
+ eval_result, eval_list = E.evaluate_rouge(preds)
192
+
193
+ elif question_type == 'multi-choice' or dataset == 'nlrv2':
194
+ if dataset in image_choice_dataset_list:
195
+ eval_result, eval_list = E.evaluate_multi_choice_image(preds)
196
+ else:
197
+ eval_result, eval_list = E.evaluate_multichoice(preds)
198
+
199
+ else:
200
+ eval_result = 'Dataset not supported'
201
+ print('Dataset not supported')
202
+ exit(0)
203
+
204
+ print(dataset, end = ': ')
205
+ print(eval_result)
206
+
207
+ eval_result_list[dataset] = eval_result
208
+ eval_result_list_detail[dataset] = eval_list
209
+
210
+ os.makedirs(args.result_dir, exist_ok=True)
211
+ with open(os.path.join(args.result_dir, 'eval_dataset.json'), 'w') as f:
212
+ json.dump(eval_result_list, f, indent=4)
213
+
214
+ with open(os.path.join(args.result_dir,'eval_dataset_details.json'), 'w') as f:
215
+ json.dump(eval_result_list_detail, f, indent=4)
216
+
217
+
218
+ eval_cat_list = dict()
219
+ print()
220
+
221
+ # spot_the_diff
222
+ score = 0
223
+ count = 0
224
+ for dataset in eval_result_list:
225
+ if dataset in spot_the_diff:
226
+ count += 1
227
+ score += list(eval_result_list[dataset].values())[0]
228
+ if count > 0:
229
+ score /= count
230
+ eval_cat_list["spot_the_diff"] = score
231
+ print("spot_the_diff", end = ': ')
232
+ print('{:.2f}'.format(100 * score))
233
+
234
+ # image_edit_instruct
235
+ score = 0
236
+ count = 0
237
+ for dataset in eval_result_list:
238
+ if dataset in image_edit_instruct:
239
+ count += 1
240
+ score += list(eval_result_list[dataset].values())[0]
241
+ if count > 0:
242
+ score /= count
243
+ eval_cat_list["image_edit_instruct"] = score
244
+ print("image_edit_instruct", end = ': ')
245
+ print('{:.2f}'.format(100 * score))
246
+
247
+ # visual_story_telling
248
+ score = 0
249
+ count = 0
250
+ for dataset in eval_result_list:
251
+ if dataset in visual_story_telling:
252
+ count += 1
253
+ score += list(eval_result_list[dataset].values())[0]
254
+ if count > 0:
255
+ score /= count
256
+ eval_cat_list["visual_story_telling"] = score
257
+ print("visual_story_telling", end = ': ')
258
+ print('{:.2f}'.format(100 * score))
259
+
260
+ # visual_cloze
261
+ score = 0
262
+ count = 0
263
+ for dataset in eval_result_list:
264
+ if dataset in visual_cloze:
265
+ count += 1
266
+ score += list(eval_result_list[dataset].values())[0]
267
+ if count > 0:
268
+ score /= count
269
+ eval_cat_list["visual_cloze"] = score
270
+ print("visual_cloze", end = ': ')
271
+ print('{:.2f}'.format(100 * score))
272
+
273
+ # text_rich_vqa
274
+ score = 0
275
+ count = 0
276
+ for dataset in eval_result_list:
277
+ if dataset in text_rich_vqa:
278
+ count += 1
279
+ score += list(eval_result_list[dataset].values())[0]
280
+ if count > 0:
281
+ score /= count
282
+ eval_cat_list["text_rich_vqa"] = score
283
+ print("text_rich_vqa", end = ': ')
284
+ print('{:.2f}'.format(100 * score))
285
+
286
+ # multi_image_vqa
287
+ score = 0
288
+ count = 0
289
+ for dataset in eval_result_list:
290
+ if dataset in multi_image_vqa:
291
+ count += 1
292
+ score += list(eval_result_list[dataset].values())[0]
293
+ if count > 0:
294
+ score /= count
295
+ eval_cat_list["multi_image_vqa"] = score
296
+ print("multi_image_vqa", end = ': ')
297
+ print('{:.2f}'.format(100 * score))
298
+
299
+ # puzzle
300
+ score = 0
301
+ count = 0
302
+ for dataset in eval_result_list:
303
+ if dataset in puzzle:
304
+ count += 1
305
+ score += list(eval_result_list[dataset].values())[0]
306
+ if count > 0:
307
+ score /= count
308
+ eval_cat_list["puzzle"] = score
309
+ print("puzzle", end = ': ')
310
+ print('{:.2f}'.format(100 * score))
311
+
312
+ # nlrv2
313
+ score = 0
314
+ count = 0
315
+ for dataset in eval_result_list:
316
+ if dataset in nlrv2:
317
+ count += 1
318
+ score += list(eval_result_list[dataset].values())[0]
319
+ if count > 0:
320
+ score /= count
321
+ eval_cat_list["nlrv2"] = score
322
+ print("nlrv2", end = ': ')
323
+ print('{:.2f}'.format(100 * score))
324
+
325
+ # qbench
326
+ score = 0
327
+ count = 0
328
+ for dataset in eval_result_list:
329
+ if dataset in qbench:
330
+ count += 1
331
+ score += list(eval_result_list[dataset].values())[0]
332
+ if count > 0:
333
+ score /= count
334
+ eval_cat_list["qbench"] = score
335
+ print("qbench", end = ': ')
336
+ print('{:.2f}'.format(100 * score))
337
+
338
+ with open(os.path.join(args.result_dir,'eval_cat.json'), 'w') as f:
339
+ json.dump(eval_cat_list, f, indent=4)
llava/eval/model_vqa.py ADDED
@@ -0,0 +1,240 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import torch
3
+ import os
4
+ import json
5
+ from tqdm import tqdm
6
+ import shortuuid
7
+
8
+ from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
9
+ from llava.conversation import conv_templates, SeparatorStyle
10
+ from llava.model.builder import load_pretrained_model
11
+ from llava.utils import disable_torch_init
12
+ from llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
13
+
14
+ from llava.constants import IGNORE_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, IMAGE_TOKEN_INDEX
15
+ from typing import Dict, Optional, Sequence, List
16
+ import transformers
17
+ import re
18
+
19
+ from PIL import Image
20
+ import math
21
+
22
+
23
+ def split_list(lst, n):
24
+ """Split a list into n (roughly) equal-sized chunks"""
25
+ chunk_size = math.ceil(len(lst) / n) # integer division
26
+ return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)]
27
+
28
+
29
+ def get_chunk(lst, n, k):
30
+ chunks = split_list(lst, n)
31
+ return chunks[k]
32
+
33
+ def preprocess_qwen(sources, tokenizer: transformers.PreTrainedTokenizer, has_image: bool = False, max_len=2048, system_message: str = "You are a helpful assistant.") -> Dict:
34
+ roles = {"human": "<|im_start|>user", "gpt": "<|im_start|>assistant"}
35
+
36
+ im_start, im_end = tokenizer.additional_special_tokens_ids
37
+ nl_tokens = tokenizer("\n").input_ids
38
+ _system = tokenizer("system").input_ids + nl_tokens
39
+ _user = tokenizer("user").input_ids + nl_tokens
40
+ _assistant = tokenizer("assistant").input_ids + nl_tokens
41
+
42
+ # Apply prompt templates
43
+ input_ids, targets = [], []
44
+
45
+ source = sources
46
+ if roles[source[0]["from"]] != roles["human"]:
47
+ source = source[1:]
48
+
49
+ input_id, target = [], []
50
+ system = [im_start] + _system + tokenizer(system_message).input_ids + [im_end] + nl_tokens
51
+ input_id += system
52
+ target += [im_start] + [IGNORE_INDEX] * (len(system) - 3) + [im_end] + nl_tokens
53
+ assert len(input_id) == len(target)
54
+ for j, sentence in enumerate(source):
55
+ role = roles[sentence["from"]]
56
+ if has_image and sentence["value"] is not None and "<image>" in sentence["value"]:
57
+ num_image = len(re.findall(DEFAULT_IMAGE_TOKEN, sentence["value"]))
58
+ texts = sentence["value"].split('<image>')
59
+ _input_id = tokenizer(role).input_ids + nl_tokens
60
+ for i,text in enumerate(texts):
61
+ _input_id += tokenizer(text).input_ids
62
+ if i<len(texts)-1:
63
+ _input_id += [IMAGE_TOKEN_INDEX] + nl_tokens
64
+ _input_id += [im_end] + nl_tokens
65
+ assert sum([i==IMAGE_TOKEN_INDEX for i in _input_id])==num_image
66
+ else:
67
+ if sentence["value"] is None:
68
+ _input_id = tokenizer(role).input_ids + nl_tokens
69
+ else:
70
+ _input_id = tokenizer(role).input_ids + nl_tokens + tokenizer(sentence["value"]).input_ids + [im_end] + nl_tokens
71
+ input_id += _input_id
72
+ if role == "<|im_start|>user":
73
+ _target = [im_start] + [IGNORE_INDEX] * (len(_input_id) - 3) + [im_end] + nl_tokens
74
+ elif role == "<|im_start|>assistant":
75
+ _target = [im_start] + [IGNORE_INDEX] * len(tokenizer(role).input_ids) + _input_id[len(tokenizer(role).input_ids) + 1 : -2] + [im_end] + nl_tokens
76
+ else:
77
+ raise NotImplementedError
78
+ target += _target
79
+
80
+ input_ids.append(input_id)
81
+ targets.append(target)
82
+ input_ids = torch.tensor(input_ids, dtype=torch.long)
83
+ targets = torch.tensor(targets, dtype=torch.long)
84
+ return input_ids
85
+
86
+ def eval_model(args):
87
+
88
+ # Model
89
+ disable_torch_init()
90
+ model_path = os.path.expanduser(args.model_path)
91
+ model_name = get_model_name_from_path(model_path)
92
+ tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name)
93
+
94
+ # Data
95
+ with open(os.path.expanduser(args.question_file)) as f:
96
+ questions = json.load(f)
97
+ questions = get_chunk(questions, args.num_chunks, args.chunk_idx)
98
+ answers_file = os.path.expanduser(args.answers_file)
99
+ os.makedirs(os.path.dirname(answers_file), exist_ok=True)
100
+ ans_file = open(answers_file, "w")
101
+
102
+ for line in tqdm(questions):
103
+ idx = line["sample_id"]
104
+ question_type = line["metadata"]["question_type"]
105
+ dataset_name = line["metadata"]["dataset"]
106
+ gt = line["conversations"][1]["value"]
107
+
108
+ image_files = line["image"]
109
+ qs = line["conversations"][0]["value"]
110
+ cur_prompt = args.extra_prompt + qs
111
+
112
+ args.conv_mode = "qwen_1_5"
113
+
114
+ conv = conv_templates[args.conv_mode].copy()
115
+ conv.append_message(conv.roles[0], qs)
116
+ conv.append_message(conv.roles[1], None)
117
+ prompt = conv.get_prompt()
118
+
119
+ input_ids = preprocess_qwen([line["conversations"][0],{'from': 'gpt','value': None}], tokenizer, has_image=True).cuda()
120
+ img_num = list(input_ids.squeeze()).count(IMAGE_TOKEN_INDEX)
121
+
122
+ image_tensors = []
123
+ for image_file in image_files:
124
+ image = Image.open(os.path.join(args.image_folder, image_file))
125
+ image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values']
126
+ image_tensors.append(image_tensor.half().cuda())
127
+ # image_tensors = torch.cat(image_tensors, dim=0)
128
+
129
+ stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
130
+ keywords = [stop_str]
131
+ stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
132
+
133
+ with torch.inference_mode():
134
+ output_ids = model.generate(
135
+ input_ids,
136
+ images=image_tensors,
137
+ do_sample=True if args.temperature > 0 else False,
138
+ temperature=args.temperature,
139
+ top_p=args.top_p,
140
+ num_beams=args.num_beams,
141
+ # no_repeat_ngram_size=3,
142
+ max_new_tokens=1024,
143
+ use_cache=True)
144
+
145
+
146
+ outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
147
+ outputs = outputs.strip()
148
+ if outputs.endswith(stop_str):
149
+ outputs = outputs[:-len(stop_str)]
150
+ outputs = outputs.strip()
151
+
152
+ ans_id = shortuuid.uuid()
153
+ ans_file.write(json.dumps({
154
+ "dataset": dataset_name,
155
+ "sample_id": idx,
156
+ "prompt": cur_prompt,
157
+ "pred_response": outputs,
158
+ "gt_response": gt,
159
+ "shortuuid": ans_id,
160
+ "model_id": model_name,
161
+ "question_type": question_type,
162
+ }) + "\n")
163
+ ans_file.flush()
164
+
165
+ if len(line["conversations"]) > 2:
166
+
167
+ for i in range(2, len(line["conversations"]), 2):
168
+ input_ids = torch.cat((input_ids, output_ids), dim=1)
169
+
170
+ gt = line["conversations"][i + 1]["value"]
171
+ qs = line["conversations"][i]["value"]
172
+ cur_prompt = args.extra_prompt + qs
173
+
174
+ args.conv_mode = "qwen_1_5"
175
+
176
+ conv = conv_templates[args.conv_mode].copy()
177
+ conv.append_message(conv.roles[0], qs)
178
+ conv.append_message(conv.roles[1], None)
179
+ prompt = conv.get_prompt()
180
+
181
+ input_ids_new = preprocess_qwen([line["conversations"][i],{'from': 'gpt','value': None}], tokenizer, has_image=True).cuda()
182
+ input_ids = torch.cat((input_ids, input_ids_new), dim=1)
183
+ img_num = list(input_ids_new.squeeze()).count(IMAGE_TOKEN_INDEX)
184
+
185
+ stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
186
+ keywords = [stop_str]
187
+ stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
188
+
189
+ with torch.inference_mode():
190
+ output_ids = model.generate(
191
+ input_ids,
192
+ images=image_tensors,
193
+ do_sample=True if args.temperature > 0 else False,
194
+ temperature=args.temperature,
195
+ top_p=args.top_p,
196
+ num_beams=args.num_beams,
197
+ # no_repeat_ngram_size=3,
198
+ max_new_tokens=1024,
199
+ use_cache=True)
200
+
201
+ outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
202
+ outputs = outputs.strip()
203
+ if outputs.endswith(stop_str):
204
+ outputs = outputs[:-len(stop_str)]
205
+ outputs = outputs.strip()
206
+
207
+ ans_id = shortuuid.uuid()
208
+ ans_file.write(json.dumps({
209
+ "dataset": dataset_name,
210
+ "sample_id": idx,
211
+ "prompt": cur_prompt,
212
+ "pred_response": outputs,
213
+ "gt_response": gt,
214
+ "shortuuid": ans_id,
215
+ "model_id": model_name,
216
+ "question_type": question_type,
217
+ }) + "\n")
218
+ ans_file.flush()
219
+
220
+
221
+ ans_file.close()
222
+
223
+ if __name__ == "__main__":
224
+ parser = argparse.ArgumentParser()
225
+ parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
226
+ parser.add_argument("--model-base", type=str, default=None)
227
+ parser.add_argument("--image-folder", type=str, default="")
228
+ parser.add_argument("--extra-prompt", type=str, default="")
229
+ parser.add_argument("--question-file", type=str, default="tables/question.jsonl")
230
+ parser.add_argument("--answers-file", type=str, default="answer.jsonl")
231
+ parser.add_argument("--conv-mode", type=str, default="llava_v1")
232
+ parser.add_argument("--num-chunks", type=int, default=1)
233
+ parser.add_argument("--chunk-idx", type=int, default=0)
234
+ parser.add_argument("--temperature", type=float, default=0.2)
235
+ parser.add_argument("--top_p", type=float, default=None)
236
+ parser.add_argument("--num_beams", type=int, default=1)
237
+ parser.add_argument("--test_size", type=int, default=10000000)
238
+ args = parser.parse_args()
239
+
240
+ eval_model(args)
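Note: a minimal sketch of driving eval_model() directly from Python instead of the command line; every path and the checkpoint name below are placeholders, not values taken from this repository.

import argparse

args = argparse.Namespace(
    model_path="path/to/llava-interleave-checkpoint",   # placeholder checkpoint directory
    model_base=None,
    image_folder="path/to/images",                      # placeholder image root
    extra_prompt="",
    question_file="path/to/interleave_questions.json",  # placeholder benchmark file
    answers_file="results/answers.jsonl",
    conv_mode="qwen_1_5",                               # overridden to "qwen_1_5" inside eval_model anyway
    num_chunks=1,
    chunk_idx=0,
    temperature=0.2,
    top_p=None,
    num_beams=1,
    test_size=10000000,
)
eval_model(args)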
llava/mm_utils.py ADDED
@@ -0,0 +1,381 @@
1
+ from PIL import Image
2
+ from io import BytesIO
3
+ import base64
4
+ import math
5
+ import ast
6
+
7
+ import torch
8
+ from transformers import StoppingCriteria
9
+ from llava.constants import IMAGE_TOKEN_INDEX
10
+
11
+
12
+ def resize_and_center_crop(image, shortest_edge_length):
13
+ # Calculate new dimensions and resize
14
+ aspect_ratio = float(image.width) / float(image.height)
15
+ if aspect_ratio > 1:
16
+ new_width = int(shortest_edge_length * aspect_ratio)
17
+ new_height = shortest_edge_length
18
+ else:
19
+ new_width = shortest_edge_length
20
+ new_height = int(shortest_edge_length / aspect_ratio)
21
+ resized_image = image.resize((new_width, new_height), Image.LANCZOS)  # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the equivalent filter
22
+
23
+ # Calculate the position and perform the center crop
24
+ left = (new_width - shortest_edge_length) / 2
25
+ top = (new_height - shortest_edge_length) / 2
26
+ right = (new_width + shortest_edge_length) / 2
27
+ bottom = (new_height + shortest_edge_length) / 2
28
+ cropped_image = resized_image.crop((left, top, right, bottom))
29
+
30
+ return cropped_image
31
+
32
+
33
+ def auto_pad_images(image, grid_params):
34
+ assert isinstance(image, Image.Image), "Input should be a Pillow Image"
35
+ assert len(grid_params) > 0, "Grid parameters should not be empty"
36
+
37
+ # Step 1: Calculate and find the closest aspect ratio
38
+ input_width, input_height = image.size
39
+ input_aspect_ratio = input_width / input_height
40
+ candidate_resolutions = [(w / h, w, h) for w in grid_params for h in grid_params]
41
+ closest_aspect_ratio = min(candidate_resolutions, key=lambda x: abs(input_aspect_ratio - x[0]))
42
+
43
+ candidate_resolutions = [(x[1], x[2]) for x in candidate_resolutions if abs(x[0] - closest_aspect_ratio[0]) < 1e-3]
44
+
45
+ target_resolution = min(candidate_resolutions, key=lambda res: abs(max(input_width, input_height) / max(res) - 1))
46
+
47
+ resize_width, resize_height = target_resolution
48
+ if input_width > input_height:
49
+ resize_height = int(resize_width / input_aspect_ratio)
50
+ else:
51
+ resize_width = int(resize_height * input_aspect_ratio)
52
+ resized_image = image.resize((resize_width, resize_height), Image.LANCZOS)  # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the equivalent filter
53
+
54
+ # Step 5: Pad the resized image if necessary to match the target resolution
55
+ pad_width = target_resolution[0] - resize_width
56
+ pad_height = target_resolution[1] - resize_height
57
+ padded_image = Image.new("RGB", target_resolution, color=(0, 0, 0))
58
+ padded_image.paste(resized_image, (pad_width // 2, pad_height // 2))
59
+
60
+ return padded_image
61
+
62
+
63
+ def extract_patches(image, patch_size, overlap_ratio):
64
+ assert isinstance(image, Image.Image), "Input should be a Pillow Image"
65
+ assert patch_size > 0, "Patch size should be greater than 0"
66
+ assert 0 <= overlap_ratio < 1, "Overlap ratio should be between 0 and 1"
67
+
68
+ W, H = image.size
69
+ patches = []
70
+
71
+ stride = int(patch_size * (1 - overlap_ratio))
72
+
73
+ num_patches_y = (H - patch_size) // stride + 1
74
+ num_patches_x = (W - patch_size) // stride + 1
75
+
76
+ y_start = (H - (num_patches_y - 1) * stride - patch_size) // 2
77
+ x_start = (W - (num_patches_x - 1) * stride - patch_size) // 2
78
+
79
+ for y in range(y_start, y_start + num_patches_y * stride, stride):
80
+ for x in range(x_start, x_start + num_patches_x * stride, stride):
81
+ patch = image.crop((x, y, x + patch_size, y + patch_size))
82
+ patches.append(patch)
83
+
84
+ return patches
85
+
86
+
87
+ def process_highres_image_crop_split(image, data_args, processor=None):
88
+ crop_resolution = data_args.image_crop_resolution
89
+ split_resolution = data_args.image_split_resolution
90
+ if processor is None:
91
+ processor = data_args.image_processor
92
+ image_crop = resize_and_center_crop(image, crop_resolution)
93
+ image_patches = extract_patches(image_crop, patch_size=split_resolution, overlap_ratio=0)
94
+ image_patches = [processor.preprocess(image_patch, return_tensors="pt")["pixel_values"][0] for image_patch in image_patches]
95
+ return torch.stack(image_patches, dim=0)
96
+
97
+
98
+ def process_highres_image(image, processor, grid_pinpoints):
99
+ grid_params = [int(x) for x in grid_pinpoints.split(",")]
100
+ width_height = max(image.size)
101
+ fit_grid_params = [x for x in grid_params if x >= width_height]
102
+ if len(fit_grid_params) == 0:
103
+ select_size = max(grid_params)
104
+ else:
105
+ select_size = min(fit_grid_params)
106
+ # FIXME: always select the 448
107
+ select_size = max(grid_params)
108
+ image_padded = expand2square(image, tuple(int(x * 255) for x in processor.image_mean))
109
+
110
+ # FIXME: this seems to be a bug that it always resizes instead of padding
111
+ image_original_resize = image.resize((processor.size["shortest_edge"], processor.size["shortest_edge"]))
112
+ image_padded = image_padded.resize((select_size, select_size))
113
+ image_patches = extract_patches(image_padded, patch_size=processor.size["shortest_edge"], overlap_ratio=0)
114
+ image_patches = [image_original_resize] + image_patches
115
+ image_patches = [processor.preprocess(image_patch, return_tensors="pt")["pixel_values"][0] for image_patch in image_patches]
116
+ return torch.stack(image_patches, dim=0)
117
+
118
+
119
+ def select_best_resolution(original_size, possible_resolutions):
120
+ """
121
+ Selects the best resolution from a list of possible resolutions based on the original size.
122
+
123
+ Args:
124
+ original_size (tuple): The original size of the image in the format (width, height).
125
+ possible_resolutions (list): A list of possible resolutions in the format [(width1, height1), (width2, height2), ...].
126
+
127
+ Returns:
128
+ tuple: The best fit resolution in the format (width, height).
129
+ """
130
+ original_width, original_height = original_size
131
+ best_fit = None
132
+ max_effective_resolution = 0
133
+ min_wasted_resolution = float("inf")
134
+
135
+ for width, height in possible_resolutions:
136
+ # Calculate the downscaled size to keep the aspect ratio
137
+ scale = min(width / original_width, height / original_height)
138
+ downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale)
139
+
140
+ # Calculate effective and wasted resolutions
141
+ effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height)
142
+ wasted_resolution = (width * height) - effective_resolution
143
+
144
+ if effective_resolution > max_effective_resolution or (effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution):
145
+ max_effective_resolution = effective_resolution
146
+ min_wasted_resolution = wasted_resolution
147
+ best_fit = (width, height)
148
+
149
+ return best_fit
150
+
151
+
152
+ def resize_and_pad_image(image, target_resolution):
153
+ """
154
+ Resize and pad an image to a target resolution while maintaining aspect ratio.
155
+
156
+ Args:
157
+ image (PIL.Image.Image): The input image.
158
+ target_resolution (tuple): The target resolution (width, height) of the image.
159
+
160
+ Returns:
161
+ PIL.Image.Image: The resized and padded image.
162
+ """
163
+ original_width, original_height = image.size
164
+ target_width, target_height = target_resolution
165
+
166
+ # Determine which dimension (width or height) to fill
167
+ scale_w = target_width / original_width
168
+ scale_h = target_height / original_height
169
+
170
+ if scale_w < scale_h:
171
+ # Width will be filled completely
172
+ new_width = target_width
173
+ new_height = min(math.ceil(original_height * scale_w), target_height)
174
+ else:
175
+ # Height will be filled completely
176
+ new_height = target_height
177
+ new_width = min(math.ceil(original_width * scale_h), target_width)
178
+
179
+ # Resize the image
180
+ resized_image = image.resize((new_width, new_height))
181
+
182
+ # Create a new image with the target size and paste the resized image onto it
183
+ new_image = Image.new("RGB", (target_width, target_height), (0, 0, 0))
184
+ paste_x = (target_width - new_width) // 2
185
+ paste_y = (target_height - new_height) // 2
186
+ new_image.paste(resized_image, (paste_x, paste_y))
187
+
188
+ return new_image
189
+
190
+
191
+ def divide_to_patches(image, patch_size):
192
+ """
193
+ Divides an image into patches of a specified size.
194
+
195
+ Args:
196
+ image (PIL.Image.Image): The input image.
197
+ patch_size (int): The size of each patch.
198
+
199
+ Returns:
200
+ list: A list of PIL.Image.Image objects representing the patches.
201
+ """
202
+ patches = []
203
+ width, height = image.size
204
+ for i in range(0, height, patch_size):
205
+ for j in range(0, width, patch_size):
206
+ box = (j, i, j + patch_size, i + patch_size)
207
+ patch = image.crop(box)
208
+ patches.append(patch)
209
+
210
+ return patches
211
+
212
+
213
+ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
214
+ """
215
+ Calculate the shape of the image patch grid after the preprocessing for images of any resolution.
216
+
217
+ Args:
218
+ image_size (tuple): The size of the input image in the format (width, height).
219
+ grid_pinpoints (str): A string representation of a list of possible resolutions.
220
+ patch_size (int): The size of each image patch.
221
+
222
+ Returns:
223
+ tuple: The shape of the image patch grid in the format (width, height).
224
+ """
225
+ if isinstance(grid_pinpoints, str):
226
+ assert patch_size in [224, 336, 384, 448, 512], "patch_size should be in [224, 336, 384, 448, 512]"
227
+ grid_pinpoints = grid_pinpoints.replace(" ", "").replace("x", ",")[1:-1].split("),(")
228
+ grid_pinpoints = [[int(x) * patch_size for x in item.split(",")] for item in grid_pinpoints]
229
+
230
+ if type(grid_pinpoints) is list:
231
+ possible_resolutions = grid_pinpoints
232
+ else:
233
+ possible_resolutions = ast.literal_eval(grid_pinpoints)
234
+ width, height = select_best_resolution(image_size, possible_resolutions)
235
+ return width // patch_size, height // patch_size
236
+
237
+
238
+ def process_anyres_image(image, processor, grid_pinpoints):
239
+ """
240
+ Process an image with variable resolutions.
241
+
242
+ Args:
243
+ image (PIL.Image.Image): The input image to be processed.
244
+ processor: The image processor object.
245
+ grid_pinpoints (str): A string representation of a list of possible resolutions.
246
+
247
+ Returns:
248
+ torch.Tensor: A tensor containing the processed image patches.
249
+ """
250
+ # Convert grid_pinpoints from string to list
251
+ if isinstance(grid_pinpoints, str):
252
+ vis_encoder_size = processor.size[0]
253
+ assert vis_encoder_size in [224, 336, 384, 448, 512], "vis_encoder_size should be in [224, 336, 384, 448, 512]"
254
+ grid_pinpoints = grid_pinpoints.replace(" ", "").replace("x", ",")[1:-1].split("),(")
255
+ grid_pinpoints = [[int(x) * vis_encoder_size for x in item.split(",")] for item in grid_pinpoints]
256
+
257
+ if type(grid_pinpoints) is list:
258
+ possible_resolutions = grid_pinpoints
259
+ else:
260
+ possible_resolutions = ast.literal_eval(grid_pinpoints)
261
+ best_resolution = select_best_resolution(image.size, possible_resolutions)
262
+ image_padded = resize_and_pad_image(image, best_resolution)
263
+
264
+ patches = divide_to_patches(image_padded, processor.crop_size["height"])
265
+
266
+ # FIXME: this seems to be a bug that it resizes instead of pad.
267
+ # but to keep it consistent with previous, i will keep it as it is
268
+ # TODO: uncomment below to ablate with the padding
269
+ if isinstance(processor.size, dict):
270
+ shortest_edge = processor.size["shortest_edge"]
271
+ else:
272
+ shortest_edge = min(processor.size)
273
+ image_original_resize = image.resize((shortest_edge, shortest_edge))
274
+ # image_padded_square = expand2square(image, tuple(int(x*255) for x in processor.image_mean))
275
+ # image_original_resize = image_padded_square.resize((processor.size['shortest_edge'], processor.size['shortest_edge']))
276
+
277
+ image_patches = [image_original_resize] + patches
278
+ image_patches = [processor.preprocess(image_patch, return_tensors="pt")["pixel_values"][0] for image_patch in image_patches]
279
+ return torch.stack(image_patches, dim=0)
280
+
281
+
282
+ def load_image_from_base64(image):
283
+ return Image.open(BytesIO(base64.b64decode(image)))
284
+
285
+
286
+ def expand2square(pil_img, background_color):
287
+ width, height = pil_img.size
288
+ if width == height:
289
+ return pil_img
290
+ elif width > height:
291
+ result = Image.new(pil_img.mode, (width, width), background_color)
292
+ result.paste(pil_img, (0, (width - height) // 2))
293
+ return result
294
+ else:
295
+ result = Image.new(pil_img.mode, (height, height), background_color)
296
+ result.paste(pil_img, ((height - width) // 2, 0))
297
+ return result
298
+
299
+
300
+ def process_images(images, image_processor, model_cfg):
301
+ image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
302
+ new_images = []
303
+ if image_aspect_ratio == "highres":
304
+ for image in images:
305
+ image = process_highres_image(image, image_processor, model_cfg.image_grid_pinpoints)
306
+ new_images.append(image)
307
+ elif image_aspect_ratio == "anyres":
308
+ for image in images:
309
+ image = process_anyres_image(image, image_processor, model_cfg.image_grid_pinpoints)
310
+ new_images.append(image)
311
+ elif image_aspect_ratio == "crop_split":
312
+ for image in images:
313
+ image = process_highres_image_crop_split(image, model_cfg, image_processor)
314
+ new_images.append(image)
315
+ elif image_aspect_ratio == "pad":
316
+ for image in images:
317
+ image = expand2square(image, tuple(int(x * 255) for x in image_processor.image_mean))
318
+ image = image_processor.preprocess(image, return_tensors="pt")["pixel_values"][0]
319
+ new_images.append(image)
320
+ else:
321
+ return image_processor(images, return_tensors="pt")["pixel_values"]
322
+ if all(x.shape == new_images[0].shape for x in new_images):
323
+ new_images = torch.stack(new_images, dim=0)
324
+ return new_images
325
+
326
+
327
+ def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
328
+ prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split("<image>")]
329
+
330
+ def insert_separator(X, sep):
331
+ return [ele for sublist in zip(X, [sep] * len(X)) for ele in sublist][:-1]
332
+
333
+ input_ids = []
334
+ offset = 0
335
+ if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
336
+ offset = 1
337
+ input_ids.append(prompt_chunks[0][0])
338
+
339
+ for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
340
+ input_ids.extend(x[offset:])
341
+
342
+ if return_tensors is not None:
343
+ if return_tensors == "pt":
344
+ return torch.tensor(input_ids, dtype=torch.long)
345
+ raise ValueError(f"Unsupported tensor type: {return_tensors}")
346
+ return input_ids
347
+
348
+
349
+ def get_model_name_from_path(model_path):
350
+ model_path = model_path.strip("/")
351
+ model_paths = model_path.split("/")
352
+ if model_paths[-1].startswith("checkpoint-"):
353
+ return model_paths[-2] + "_" + model_paths[-1]
354
+ else:
355
+ return model_paths[-1]
356
+
357
+
358
+ class KeywordsStoppingCriteria(StoppingCriteria):
359
+ def __init__(self, keywords, tokenizer, input_ids):
360
+ self.keywords = keywords
361
+ self.keyword_ids = []
362
+ for keyword in keywords:
363
+ cur_keyword_ids = tokenizer(keyword).input_ids
364
+ if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
365
+ cur_keyword_ids = cur_keyword_ids[1:]
366
+ self.keyword_ids.append(torch.tensor(cur_keyword_ids))
367
+ self.tokenizer = tokenizer
368
+ self.start_len = input_ids.shape[1]
369
+
370
+ def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
371
+ assert output_ids.shape[0] == 1, "Only support batch size 1 (yet)" # TODO
372
+ offset = min(output_ids.shape[1] - self.start_len, 3)
373
+ self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
374
+ for keyword_id in self.keyword_ids:
375
+ if (output_ids[0, -keyword_id.shape[0] :] == keyword_id).all():
376
+ return True
377
+ outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
378
+ for keyword in self.keywords:
379
+ if keyword in outputs:
380
+ return True
381
+ return False
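For orientation, a small self-contained sketch of how the anyres helpers above behave; the 1000x700 image size and the 336-pixel grid values are illustrative, not settings shipped with this repo.

from llava.mm_utils import select_best_resolution, get_anyres_image_grid_shape

# Candidate canvases for a 336px encoder in 1x1 .. 2x2 layouts (illustrative values).
possible = [(336, 336), (336, 672), (672, 336), (672, 672)]

# A 1000x700 image keeps the most effective pixels on the 672x672 canvas.
print(select_best_resolution((1000, 700), possible))             # (672, 672)

# The same choice expressed as a patch grid: 2 patches wide by 2 patches high.
print(get_anyres_image_grid_shape((1000, 700), possible, 336))   # (2, 2)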
llava/model/__init__.py ADDED
@@ -0,0 +1,20 @@
+ import os
+
+ AVAILABLE_MODELS = {
+     "llava_llama": "LlavaLlamaForCausalLM, LlavaConfig",
+     "llava_gemma": "LlavaGemmaForCausalLM, LlavaGemmaConfig",
+     "llava_qwen": "LlavaQwenForCausalLM, LlavaQwenConfig",
+     # "llava_qwen_moe": "LlavaQwenMoeForCausalLM, LlavaQwenMoeConfig",
+     "llava_mistral": "LlavaMistralForCausalLM, LlavaMistralConfig",
+     "llava_mixtral": "LlavaMixtralForCausalLM, LlavaMixtralConfig",
+     # Add other models as needed
+ }
+
+ for model_name, model_classes in AVAILABLE_MODELS.items():
+     try:
+         exec(f"from .language_model.{model_name} import {model_classes}")
+     except ImportError:
+         # import traceback
+         # traceback.print_exc()
+         print(f"Failed to import {model_name} from llava.language_model.{model_name}")
+         pass
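Every entry that imports cleanly is re-exported at package level, which is what lets builder.py below rely on `from llava.model import *`; a quick sanity check (assuming the qwen and llama imports succeeded) is simply:

from llava.model import LlavaQwenForCausalLM, LlavaConfig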
llava/model/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (740 Bytes). View file
 
llava/model/__pycache__/builder.cpython-310.pyc ADDED
Binary file (6.9 kB). View file
 
llava/model/__pycache__/llava_arch.cpython-310.pyc ADDED
Binary file (11.7 kB). View file
 
llava/model/apply_delta.py ADDED
@@ -0,0 +1,47 @@
1
+ """
2
+ Usage:
3
+ python3 -m fastchat.model.apply_delta --base ~/model_weights/llama-7b --target ~/model_weights/vicuna-7b --delta lmsys/vicuna-7b-delta
4
+ """
5
+
6
+ import argparse
7
+
8
+ import torch
9
+ from tqdm import tqdm
10
+ from transformers import AutoTokenizer, AutoModelForCausalLM
11
+ from llava import LlavaLlamaForCausalLM
12
+
13
+
14
+ def apply_delta(base_model_path, target_model_path, delta_path):
15
+ print("Loading base model")
16
+ base = AutoModelForCausalLM.from_pretrained(base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
17
+
18
+ print("Loading delta")
19
+ delta = LlavaLlamaForCausalLM.from_pretrained(delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
20
+ delta_tokenizer = AutoTokenizer.from_pretrained(delta_path)
21
+
22
+ print("Applying delta")
23
+ for name, param in tqdm(delta.state_dict().items(), desc="Applying delta"):
24
+ if name not in base.state_dict():
25
+ assert name in ["model.mm_projector.weight", "model.mm_projector.bias"], f"{name} not in base model"
26
+ continue
27
+ if param.data.shape == base.state_dict()[name].shape:
28
+ param.data += base.state_dict()[name]
29
+ else:
30
+ assert name in ["model.embed_tokens.weight", "lm_head.weight"], f"{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}"
31
+ bparam = base.state_dict()[name]
32
+ param.data[: bparam.shape[0], : bparam.shape[1]] += bparam
33
+
34
+ print("Saving target model")
35
+ delta.save_pretrained(target_model_path)
36
+ delta_tokenizer.save_pretrained(target_model_path)
37
+
38
+
39
+ if __name__ == "__main__":
40
+ parser = argparse.ArgumentParser()
41
+ parser.add_argument("--base-model-path", type=str, required=True)
42
+ parser.add_argument("--target-model-path", type=str, required=True)
43
+ parser.add_argument("--delta-path", type=str, required=True)
44
+
45
+ args = parser.parse_args()
46
+
47
+ apply_delta(args.base_model_path, args.target_model_path, args.delta_path)
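The delta rule above is plain element-wise addition, with embedding and lm_head rows padded when the delta vocabulary is larger than the base; a toy illustration with made-up shapes:

import torch

base = torch.ones(4, 3)      # pretend base weight with a 4-token vocabulary
delta = torch.zeros(6, 3)    # pretend delta weight with 2 extra tokens
delta[: base.shape[0], : base.shape[1]] += base   # same update as the shape-mismatch branch above
# Rows 0-3 now hold base + delta; rows 4-5 keep the delta-only values for the new tokens.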
llava/model/builder.py ADDED
@@ -0,0 +1,250 @@
1
+ # Copyright 2023 Haotian Liu
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ import os
17
+ import warnings
18
+ import shutil
19
+
20
+ from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig
21
+ import torch
22
+ from llava.model import *
23
+ from llava.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
24
+ from llava.utils import rank0_print
25
+
26
+
27
+ def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", attn_implementation="flash_attention_2", customized_config=None, **kwargs):
28
+ kwargs = {"device_map": device_map}
29
+
30
+ if load_8bit:
31
+ kwargs["load_in_8bit"] = True
32
+ elif load_4bit:
33
+ kwargs["load_in_4bit"] = True
34
+ kwargs["quantization_config"] = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4")
35
+ else:
36
+ kwargs["torch_dtype"] = torch.float16
37
+
38
+ if customized_config is not None:
39
+ kwargs["config"] = customized_config
40
+
41
+ if "llava" in model_name.lower():
42
+ # Load LLaVA model
43
+ if "lora" in model_name.lower() and model_base is None:
44
+ warnings.warn(
45
+ "There is `lora` in model name but no `model_base` is provided. If you are loading a LoRA model, please provide the `model_base` argument. Detailed instruction: https://github.com/haotian-liu/LLaVA#launch-a-model-worker-lora-weights-unmerged."
46
+ )
47
+ if "lora" in model_name.lower() and model_base is not None:
48
+ lora_cfg_pretrained = AutoConfig.from_pretrained(model_path)
49
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
50
+ rank0_print("Loading LLaVA from base model...")
51
+ if "mixtral" in model_name.lower():
52
+ from llava.model.language_model.llava_mixtral import LlavaMixtralConfig
53
+
54
+ lora_cfg_pretrained = LlavaMixtralConfig.from_pretrained(model_path)
55
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
56
+ model = LlavaMixtralForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, attn_implementation=attn_implementation, **kwargs)
57
+ elif "mistral" in model_name.lower():
58
+ from llava.model.language_model.llava_mistral import LlavaMistralConfig
59
+
60
+ lora_cfg_pretrained = LlavaMistralConfig.from_pretrained(model_path)
61
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
62
+ model = LlavaMistralForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, attn_implementation=attn_implementation, **kwargs)
63
+ elif "gemma" in model_name.lower():
64
+ from llava.model.language_model.llava_gemma import LlavaGemmaConfig
65
+
66
+ lora_cfg_pretrained = LlavaGemmaConfig.from_pretrained(model_path)
67
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
68
+ model = LlavaGemmaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, attn_implementation=attn_implementation, **kwargs)
69
+ else:
70
+ from llava.model.language_model.llava_llama import LlavaConfig
71
+
72
+ lora_cfg_pretrained = LlavaConfig.from_pretrained(model_path)
73
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
74
+ model = LlavaLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, attn_implementation=attn_implementation, **kwargs)
75
+
76
+ token_num, token_dim = model.lm_head.out_features, model.lm_head.in_features
77
+ if model.lm_head.weight.shape[0] != token_num:
78
+ model.lm_head.weight = torch.nn.Parameter(torch.empty(token_num, token_dim, device=model.device, dtype=model.dtype))
79
+ model.model.embed_tokens.weight = torch.nn.Parameter(torch.empty(token_num, token_dim, device=model.device, dtype=model.dtype))
80
+
81
+ rank0_print("Loading additional LLaVA weights...")
82
+ if os.path.exists(os.path.join(model_path, "non_lora_trainables.bin")):
83
+ non_lora_trainables = torch.load(os.path.join(model_path, "non_lora_trainables.bin"), map_location="cpu")
84
+ else:
85
+ # this is probably from HF Hub
86
+ from huggingface_hub import hf_hub_download
87
+
88
+ def load_from_hf(repo_id, filename, subfolder=None):
89
+ cache_file = hf_hub_download(repo_id=repo_id, filename=filename, subfolder=subfolder)
90
+ return torch.load(cache_file, map_location="cpu")
91
+
92
+ non_lora_trainables = load_from_hf(model_path, "non_lora_trainables.bin")
93
+ non_lora_trainables = {(k[11:] if k.startswith("base_model.") else k): v for k, v in non_lora_trainables.items()}
94
+ if any(k.startswith("model.model.") for k in non_lora_trainables):
95
+ non_lora_trainables = {(k[6:] if k.startswith("model.") else k): v for k, v in non_lora_trainables.items()}
96
+ model.load_state_dict(non_lora_trainables, strict=False)
97
+
98
+ from peft import PeftModel
99
+
100
+ rank0_print("Loading LoRA weights...")
101
+ model = PeftModel.from_pretrained(model, model_path)
102
+ rank0_print("Merging LoRA weights...")
103
+ model = model.merge_and_unload()
104
+ rank0_print("Model is loaded...")
105
+ elif model_base is not None:
106
+ # this may be mm projector only
107
+ rank0_print(f"Loading LLaVA from base model {model_base}...")
108
+ if "mixtral" in model_name.lower():
109
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
110
+ cfg_pretrained = AutoConfig.from_pretrained(model_path)
111
+ model = LlavaMixtralForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, attn_implementation=attn_implementation, **kwargs)
112
+ elif "mistral" in model_name.lower() or "zephyr" in model_name.lower():
113
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
114
+ cfg_pretrained = AutoConfig.from_pretrained(model_path)
115
+ model = LlavaMistralForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, attn_implementation=attn_implementation, **kwargs)
116
+ elif "gemma" in model_name.lower():
117
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
118
+ cfg_pretrained = AutoConfig.from_pretrained(model_path)
119
+ model = LlavaGemmaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, attn_implementation=attn_implementation, **kwargs)
120
+ elif (
121
+ "wizardlm-2" in model_name.lower()
122
+ and "vicuna" in model_name.lower()
123
+ or "llama" in model_name.lower()
124
+ or "yi" in model_name.lower()
125
+ or "nous-hermes" in model_name.lower()
126
+ or "llava-v1.6-34b" in model_name.lower()
127
+ or "llava-v1.5" in model_name.lower()
128
+ ):
129
+ from llava.model.language_model.llava_llama import LlavaConfig
130
+
131
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
132
+ if customized_config is None:
133
+ llava_cfg = LlavaConfig.from_pretrained(model_path)
134
+ if "v1.5" in model_name.lower():
135
+ llava_cfg.delay_load = True # a workaround for correctly loading v1.5 models
136
+ else:
137
+ llava_cfg = customized_config
138
+
139
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
140
+ llava_cfg = LlavaConfig.from_pretrained(model_path)
141
+ model = LlavaLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=llava_cfg, **kwargs)
142
+ else:
143
+ raise ValueError(f"Model {model_name} not supported")
144
+
145
+ mm_projector_weights = torch.load(os.path.join(model_path, "mm_projector.bin"), map_location="cpu")
146
+ mm_projector_weights = {k: v.to(torch.float16) for k, v in mm_projector_weights.items()}
147
+ model.load_state_dict(mm_projector_weights, strict=False)
148
+ else:
149
+ rank0_print(f"Loaded LLaVA model: {model_path}")
150
+ if "mixtral" in model_name.lower():
151
+ tokenizer = AutoTokenizer.from_pretrained(model_path)
152
+ model = LlavaMixtralForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, attn_implementation=attn_implementation, **kwargs)
153
+ elif "mistral" in model_name.lower() or "zephyr" in model_name.lower():
154
+ tokenizer = AutoTokenizer.from_pretrained(model_path)
155
+ model = LlavaMistralForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, attn_implementation=attn_implementation, **kwargs)
156
+ elif (
157
+ "wizardlm-2" in model_name.lower()
158
+ and "vicuna" in model_name.lower()
159
+ or "llama" in model_name.lower()
160
+ or "yi" in model_name.lower()
161
+ or "nous-hermes" in model_name.lower()
162
+ or "llava-v1.6-34b" in model_name.lower()
163
+ or "llava-v1.5" in model_name.lower()
164
+ ):
165
+ from llava.model.language_model.llava_llama import LlavaConfig
166
+
167
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
168
+ if customized_config is None:
169
+ llava_cfg = LlavaConfig.from_pretrained(model_path)
170
+ if "v1.5" in model_name.lower():
171
+ llava_cfg.delay_load = True # a workaround for correctly loading v1.5 models
172
+ else:
173
+ llava_cfg = customized_config
174
+
175
+ model = LlavaLlamaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, attn_implementation=attn_implementation, config=llava_cfg, **kwargs)
176
+ elif "qwen" in model_name.lower():
177
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
178
+ model = LlavaQwenForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, attn_implementation=attn_implementation, **kwargs)
179
+ elif "gemma" in model_name.lower():
180
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
181
+ cfg_pretrained = AutoConfig.from_pretrained(model_path)
182
+ model = LlavaGemmaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, config=cfg_pretrained, attn_implementation=attn_implementation, **kwargs)
183
+ else:
184
+ rank0_print("\n\n\nWarning: no matching LLaVA architecture found; falling back to llava_llama. If this is not intended, include the architecture in model_name.\n\n\n")
185
+ try:
186
+ from llava.model.language_model.llava_llama import LlavaConfig
187
+
188
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
189
+ if customized_config is None:
190
+ llava_cfg = LlavaConfig.from_pretrained(model_path)
191
+ if "v1.5" in model_path.lower():
192
+ llava_cfg.delay_load = True # a workaround for correctly loading v1.5 models
193
+ else:
194
+ llava_cfg = customized_config
195
+ model = LlavaLlamaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, attn_implementation=attn_implementation, config=llava_cfg, **kwargs)
196
+ except Exception:
197
+ raise ValueError(f"Model {model_name} not supported")
198
+
199
+ else:
200
+ # Load language model
201
+ if model_base is not None:
202
+ # PEFT model
203
+ from peft import PeftModel
204
+
205
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
206
+ model = AutoModelForCausalLM.from_pretrained(model_base, torch_dtype=torch.float16, low_cpu_mem_usage=True, device_map="auto")
207
+ print(f"Loading LoRA weights from {model_path}")
208
+ model = PeftModel.from_pretrained(model, model_path)
209
+ print(f"Merging weights")
210
+ model = model.merge_and_unload()
211
+ print("Convert to FP16...")
212
+ model.to(torch.float16)
213
+ else:
214
+ use_fast = False
215
+ if "mpt" in model_name.lower().replace("prompt", ""):
216
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
217
+ model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, trust_remote_code=True, **kwargs)
218
+ else:
219
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
220
+ model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
221
+
222
+ rank0_print(f"Model Class: {model.__class__.__name__}")
223
+ image_processor = None
224
+
225
+ if "llava" in model_name.lower():
226
+ mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
227
+ mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True)
228
+ if mm_use_im_patch_token:
229
+ tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
230
+ if mm_use_im_start_end:
231
+ tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
232
+ model.resize_token_embeddings(len(tokenizer))
233
+
234
+ vision_tower = model.get_vision_tower()
235
+ if not vision_tower.is_loaded:
236
+ vision_tower.load_model(device_map=device_map)
237
+ if device_map != "auto":
238
+ vision_tower.to(device="cuda", dtype=torch.float16)
239
+ image_processor = vision_tower.image_processor
240
+
241
+ if hasattr(model.config, "max_sequence_length"):
242
+ context_len = model.config.max_sequence_length
243
+ elif hasattr(model.config, "max_position_embeddings"):
244
+ context_len = model.config.max_position_embeddings
245
+ elif hasattr(model.config, "tokenizer_model_max_length"):
246
+ context_len = model.config.tokenizer_model_max_length
247
+ else:
248
+ context_len = 2048
249
+
250
+ return tokenizer, model, image_processor, context_len
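A minimal sketch of the loader's contract as consumed by the eval script above; the checkpoint path is a placeholder, and attn_implementation="sdpa" is only an assumed override for machines without flash-attention:

from llava.model.builder import load_pretrained_model
from llava.mm_utils import get_model_name_from_path

model_path = "path/to/llava-qwen-checkpoint"   # placeholder
tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path,
    None,                                      # model_base is only needed for LoRA / projector-only checkpoints
    get_model_name_from_path(model_path),
    load_4bit=False,
    attn_implementation="sdpa",                # assumption: replace the flash_attention_2 default if it is unavailable
)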
llava/model/consolidate.py ADDED
@@ -0,0 +1,30 @@
+ """
+ Usage:
+ python3 -m llava.model.consolidate --src ~/model_weights/llava-7b --dst ~/model_weights/llava-7b_consolidate
+ """
+
+ import argparse
+
+ import torch
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+ from llava.model import *
+ from llava.model.utils import auto_upgrade
+
+
+ def consolidate_ckpt(src_path, dst_path):
+     print("Loading model")
+     auto_upgrade(src_path)
+     src_model = AutoModelForCausalLM.from_pretrained(src_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
+     src_tokenizer = AutoTokenizer.from_pretrained(src_path, use_fast=False)
+     src_model.save_pretrained(dst_path)
+     src_tokenizer.save_pretrained(dst_path)
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--src", type=str, required=True)
+     parser.add_argument("--dst", type=str, required=True)
+
+     args = parser.parse_args()
+
+     consolidate_ckpt(args.src, args.dst)
llava/model/language_model/__pycache__/llava_gemma.cpython-310.pyc ADDED
Binary file (3.79 kB). View file
 
llava/model/language_model/__pycache__/llava_llama.cpython-310.pyc ADDED
Binary file (3.98 kB). View file
 
llava/model/language_model/__pycache__/llava_mistral.cpython-310.pyc ADDED
Binary file (4.04 kB). View file
 
llava/model/language_model/__pycache__/llava_mixtral.cpython-310.pyc ADDED
Binary file (3.86 kB). View file
 
llava/model/language_model/__pycache__/llava_qwen.cpython-310.pyc ADDED
Binary file (3.9 kB). View file
 
llava/model/language_model/llava_gemma.py ADDED
@@ -0,0 +1,122 @@
1
+ # Copyright 2024 Duc Q. Nguyen, Haotian Liu and Bo Li
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from typing import List, Optional, Tuple, Union
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ from torch.nn import CrossEntropyLoss
21
+
22
+ from transformers import AutoConfig, AutoModelForCausalLM, GemmaConfig, GemmaModel, GemmaForCausalLM
23
+
24
+ from transformers.modeling_outputs import CausalLMOutputWithPast
25
+ from transformers.generation.utils import GenerateOutput
26
+
27
+ from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
28
+
29
+
30
+ class LlavaGemmaConfig(GemmaConfig):
31
+ model_type = "llava_gemma"
32
+
33
+
34
+ class LlavaGemmaModel(LlavaMetaModel, GemmaModel):
35
+ config_class = LlavaGemmaConfig
36
+
37
+ def __init__(self, config: GemmaConfig):
38
+ super(LlavaGemmaModel, self).__init__(config)
39
+
40
+
41
+ class LlavaGemmaForCausalLM(GemmaForCausalLM, LlavaMetaForCausalLM):
42
+ config_class = LlavaGemmaConfig
43
+
44
+ def __init__(self, config):
45
+ super(GemmaForCausalLM, self).__init__(config)
46
+ self.model = LlavaGemmaModel(config)
47
+
48
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
49
+
50
+ # Initialize weights and apply final processing
51
+ self.post_init()
52
+
53
+ def get_model(self):
54
+ return self.model
55
+
56
+ def forward(
57
+ self,
58
+ input_ids: torch.LongTensor = None,
59
+ attention_mask: Optional[torch.Tensor] = None,
60
+ position_ids: Optional[torch.LongTensor] = None,
61
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
62
+ inputs_embeds: Optional[torch.FloatTensor] = None,
63
+ labels: Optional[torch.LongTensor] = None,
64
+ use_cache: Optional[bool] = None,
65
+ output_attentions: Optional[bool] = None,
66
+ output_hidden_states: Optional[bool] = None,
67
+ images: Optional[torch.FloatTensor] = None,
68
+ image_sizes: Optional[List[List[int]]] = None,
69
+ return_dict: Optional[bool] = None,
70
+ cache_position: Optional[torch.LongTensor] = None,
71
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
72
+
73
+ if inputs_embeds is None:
74
+ (input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels) = self.prepare_inputs_labels_for_multimodal(input_ids, position_ids, attention_mask, past_key_values, labels, images, image_sizes)
75
+
76
+ return super().forward(
77
+ input_ids=input_ids,
78
+ attention_mask=attention_mask,
79
+ position_ids=position_ids,
80
+ past_key_values=past_key_values,
81
+ inputs_embeds=inputs_embeds,
82
+ labels=labels,
83
+ use_cache=use_cache,
84
+ output_attentions=output_attentions,
85
+ output_hidden_states=output_hidden_states,
86
+ return_dict=return_dict,
87
+ cache_position=cache_position,
88
+ )
89
+
90
+ @torch.no_grad()
91
+ def generate(
92
+ self,
93
+ inputs: Optional[torch.Tensor] = None,
94
+ images: Optional[torch.Tensor] = None,
95
+ image_sizes: Optional[torch.Tensor] = None,
96
+ **kwargs,
97
+ ) -> Union[GenerateOutput, torch.LongTensor]:
98
+ position_ids = kwargs.pop("position_ids", None)
99
+ attention_mask = kwargs.pop("attention_mask", None)
100
+ if "inputs_embeds" in kwargs:
101
+ raise NotImplementedError("`inputs_embeds` is not supported")
102
+
103
+ if images is not None:
104
+ (inputs, position_ids, attention_mask, _, inputs_embeds, _) = self.prepare_inputs_labels_for_multimodal(inputs, position_ids, attention_mask, None, None, images, image_sizes=image_sizes)
105
+ else:
106
+ inputs_embeds = self.get_model().embed_tokens(inputs)
107
+
108
+ return super().generate(position_ids=position_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, **kwargs)
109
+
110
+ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
111
+ images = kwargs.pop("images", None)
112
+ image_sizes = kwargs.pop("image_sizes", None)
113
+ inputs = super().prepare_inputs_for_generation(input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs)
114
+ if images is not None:
115
+ inputs["images"] = images
116
+ if image_sizes is not None:
117
+ inputs["image_sizes"] = image_sizes
118
+ return inputs
119
+
120
+
121
+ AutoConfig.register("llava_gemma", LlavaGemmaConfig)
122
+ AutoModelForCausalLM.register(LlavaGemmaConfig, LlavaGemmaForCausalLM)
llava/model/language_model/llava_llama.py ADDED
@@ -0,0 +1,131 @@
1
+ # Copyright 2023 Haotian Liu
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from typing import List, Optional, Tuple, Union
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+
21
+ from transformers import AutoConfig, AutoModelForCausalLM, LlamaConfig
22
+
23
+ # , LlamaModel, LlamaForCausalLM, GenerationConfig
24
+ # from .modeling_llama import LlamaModel, LlamaForCausalLM
25
+ from transformers import LlamaModel, LlamaForCausalLM
26
+ from transformers.modeling_outputs import CausalLMOutputWithPast
27
+ from transformers.generation.utils import GenerateOutput
28
+
29
+ from llava.model.llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
30
+
31
+
32
+ class LlavaConfig(LlamaConfig):
33
+ model_type = "llava_llama"
34
+ temperature: float = 0.0 # reset to 0.0, previously 0.9 for Vicuna
35
+ max_new_tokens: int = 1024
36
+ do_sample: bool = False
37
+ top_p: Optional[float] = None
38
+ rope_scaling: Optional[dict] = {}
39
+
40
+
41
+ class LlavaLlamaModel(LlavaMetaModel, LlamaModel):
42
+ config_class = LlavaConfig
43
+
44
+ def __init__(self, config: LlamaConfig):
45
+ super(LlavaLlamaModel, self).__init__(config)
46
+
47
+
48
+ class LlavaLlamaForCausalLM(LlamaForCausalLM, LlavaMetaForCausalLM):
49
+ config_class = LlavaConfig
50
+
51
+ def __init__(self, config):
52
+ LlamaForCausalLM.__init__(self, config)
53
+
54
+ # configure default generation settings
55
+ config.model_type = "llava_llama"
56
+ config.rope_scaling = None
57
+
58
+ self.model = LlavaLlamaModel(config)
59
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
60
+ # Initialize weights and apply final processing
61
+ self.post_init()
62
+
63
+ def get_model(self):
64
+ return self.model
65
+
66
+ def forward(
67
+ self,
68
+ input_ids: torch.LongTensor = None,
69
+ attention_mask: Optional[torch.Tensor] = None,
70
+ position_ids: Optional[torch.LongTensor] = None,
71
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
72
+ inputs_embeds: Optional[torch.FloatTensor] = None,
73
+ labels: Optional[torch.LongTensor] = None,
74
+ use_cache: Optional[bool] = None,
75
+ output_attentions: Optional[bool] = None,
76
+ output_hidden_states: Optional[bool] = None,
77
+ images: Optional[torch.FloatTensor] = None,
78
+ image_sizes: Optional[List[List[int]]] = None,
79
+ return_dict: Optional[bool] = None,
80
+ cache_position=None,
81
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
82
+
83
+ if inputs_embeds is None:
84
+ (input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels) = self.prepare_inputs_labels_for_multimodal(input_ids, position_ids, attention_mask, past_key_values, labels, images, image_sizes)
85
+
86
+ return super().forward(
87
+ input_ids=input_ids,
88
+ attention_mask=attention_mask,
89
+ position_ids=position_ids,
90
+ past_key_values=past_key_values,
91
+ inputs_embeds=inputs_embeds,
92
+ labels=labels,
93
+ use_cache=use_cache,
94
+ output_attentions=output_attentions,
95
+ output_hidden_states=output_hidden_states,
96
+ return_dict=return_dict,
97
+ )
98
+
99
+ @torch.no_grad()
100
+ def generate(
101
+ self,
102
+ inputs: Optional[torch.Tensor] = None,
103
+ images: Optional[torch.Tensor] = None,
104
+ image_sizes: Optional[torch.Tensor] = None,
105
+ **kwargs,
106
+ ) -> Union[GenerateOutput, torch.LongTensor]:
107
+ position_ids = kwargs.pop("position_ids", None)
108
+ attention_mask = kwargs.pop("attention_mask", None)
109
+ if "inputs_embeds" in kwargs:
110
+ raise NotImplementedError("`inputs_embeds` is not supported")
111
+
112
+ if images is not None:
113
+ (inputs, position_ids, attention_mask, _, inputs_embeds, _) = self.prepare_inputs_labels_for_multimodal(inputs, position_ids, attention_mask, None, None, images, image_sizes=image_sizes)
114
+ else:
115
+ inputs_embeds = self.get_model().embed_tokens(inputs)
116
+
117
+ return super().generate(position_ids=position_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, **kwargs)
118
+
119
+ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
120
+ images = kwargs.pop("images", None)
121
+ image_sizes = kwargs.pop("image_sizes", None)
122
+ inputs = super().prepare_inputs_for_generation(input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs)
123
+ if images is not None:
124
+ inputs["images"] = images
125
+ if image_sizes is not None:
126
+ inputs["image_sizes"] = image_sizes
127
+ return inputs
128
+
129
+
130
+ AutoConfig.register("llava_llama", LlavaConfig)
131
+ AutoModelForCausalLM.register(LlavaConfig, LlavaLlamaForCausalLM)
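Because of the two register() calls above, any checkpoint whose config.json declares model_type "llava_llama" resolves through the Hugging Face Auto classes once llava.model has been imported; a hypothetical example:

from transformers import AutoConfig, AutoModelForCausalLM
import llava.model  # runs the register() calls as a side effect

cfg = AutoConfig.from_pretrained("path/to/llava_llama_checkpoint")  # hypothetical local checkpoint
assert cfg.model_type == "llava_llama"
model = AutoModelForCausalLM.from_config(cfg)                       # instantiates LlavaLlamaForCausalLM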
llava/model/language_model/llava_mistral.py ADDED
@@ -0,0 +1,127 @@
1
+ # Copyright 2023 Haotian Liu
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from typing import List, Optional, Tuple, Union
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ from torch.nn import CrossEntropyLoss
21
+
22
+ from transformers import AutoConfig, AutoModelForCausalLM, MistralConfig, MistralModel, MistralForCausalLM, GenerationConfig
23
+
24
+ from transformers.modeling_outputs import CausalLMOutputWithPast
25
+ from transformers.generation.utils import GenerateOutput
26
+
27
+ from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
28
+
29
+
30
+ class LlavaMistralConfig(MistralConfig):
31
+ model_type = "llava_mistral"
32
+ temperature: float = 0.0 # reset to 0.0, previously 0.9 for Vicuna
33
+ max_new_tokens: int = 1024
34
+ do_sample: bool = False
35
+ top_p: Optional[float] = None
36
+
37
+
38
+ class LlavaMistralModel(LlavaMetaModel, MistralModel):
39
+ config_class = LlavaMistralConfig
40
+
41
+ def __init__(self, config: MistralConfig):
42
+ super(LlavaMistralModel, self).__init__(config)
43
+
44
+
45
+ class LlavaMistralForCausalLM(MistralForCausalLM, LlavaMetaForCausalLM):
46
+ config_class = LlavaMistralConfig
47
+
48
+ def __init__(self, config):
49
+ super(MistralForCausalLM, self).__init__(config)
50
+
51
+ config.model_type = "llava_mistral"
52
+ config.rope_scaling = None
53
+
54
+ self.model = LlavaMistralModel(config)
55
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
56
+ # Initialize weights and apply final processing
57
+ self.post_init()
58
+
59
+ def get_model(self):
60
+ return self.model
61
+
62
+ def forward(
63
+ self,
64
+ input_ids: torch.LongTensor = None,
65
+ attention_mask: Optional[torch.Tensor] = None,
66
+ position_ids: Optional[torch.LongTensor] = None,
67
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
68
+ inputs_embeds: Optional[torch.FloatTensor] = None,
69
+ labels: Optional[torch.LongTensor] = None,
70
+ use_cache: Optional[bool] = None,
71
+ output_attentions: Optional[bool] = None,
72
+ output_hidden_states: Optional[bool] = None,
73
+ images: Optional[torch.FloatTensor] = None,
74
+ image_sizes: Optional[List[List[int]]] = None,
75
+ return_dict: Optional[bool] = None,
76
+ cache_position=None,
77
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
78
+
79
+ if inputs_embeds is None:
80
+ (input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels) = self.prepare_inputs_labels_for_multimodal(input_ids, position_ids, attention_mask, past_key_values, labels, images, image_sizes)
81
+
82
+ return super().forward(
83
+ input_ids=input_ids,
84
+ attention_mask=attention_mask,
85
+ position_ids=position_ids,
86
+ past_key_values=past_key_values,
87
+ inputs_embeds=inputs_embeds,
88
+ labels=labels,
89
+ use_cache=use_cache,
90
+ output_attentions=output_attentions,
91
+ output_hidden_states=output_hidden_states,
92
+ return_dict=return_dict,
93
+ )
94
+
95
+ @torch.no_grad()
96
+ def generate(
97
+ self,
98
+ inputs: Optional[torch.Tensor] = None,
99
+ images: Optional[torch.Tensor] = None,
100
+ image_sizes: Optional[torch.Tensor] = None,
101
+ **kwargs,
102
+ ) -> Union[GenerateOutput, torch.LongTensor]:
103
+ position_ids = kwargs.pop("position_ids", None)
104
+ attention_mask = kwargs.pop("attention_mask", None)
105
+ if "inputs_embeds" in kwargs:
106
+ raise NotImplementedError("`inputs_embeds` is not supported")
107
+
108
+ if images is not None:
109
+ (inputs, position_ids, attention_mask, _, inputs_embeds, _) = self.prepare_inputs_labels_for_multimodal(inputs, position_ids, attention_mask, None, None, images, image_sizes=image_sizes)
110
+ else:
111
+ inputs_embeds = self.get_model().embed_tokens(inputs)
112
+
113
+ return super().generate(position_ids=position_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, **kwargs)
114
+
115
+ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
116
+ images = kwargs.pop("images", None)
117
+ image_sizes = kwargs.pop("image_sizes", None)
118
+ inputs = super().prepare_inputs_for_generation(input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs)
119
+ if images is not None:
120
+ inputs["images"] = images
121
+ if image_sizes is not None:
122
+ inputs["image_sizes"] = image_sizes
123
+ return inputs
124
+
125
+
126
+ AutoConfig.register("llava_mistral", LlavaMistralConfig)
127
+ AutoModelForCausalLM.register(LlavaMistralConfig, LlavaMistralForCausalLM)
llava/model/language_model/llava_mixtral.py ADDED
@@ -0,0 +1,122 @@
1
+ # Copyright 2023 Haotian Liu
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from typing import List, Optional, Tuple, Union
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ from torch.nn import CrossEntropyLoss
21
+
22
+ from transformers import AutoConfig, AutoModelForCausalLM, MixtralConfig, MixtralModel, MixtralForCausalLM, GenerationConfig
23
+
24
+ from transformers.modeling_outputs import CausalLMOutputWithPast
25
+ from transformers.generation.utils import GenerateOutput
26
+
27
+ from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
28
+
29
+
30
+ class LlavaMixtralConfig(MixtralConfig):
31
+ model_type = "llava_mixtral"
32
+
33
+
34
+ class LlavaMixtralModel(LlavaMetaModel, MixtralModel):
35
+ config_class = LlavaMixtralConfig
36
+
37
+ def __init__(self, config: MixtralConfig):
38
+ super(LlavaMixtralModel, self).__init__(config)
39
+
40
+
41
+ class LlavaMixtralForCausalLM(MixtralForCausalLM, LlavaMetaForCausalLM):
42
+ config_class = LlavaMixtralConfig
43
+
44
+ def __init__(self, config):
45
+ super(MixtralForCausalLM, self).__init__(config)
46
+
47
+ config.model_type = "llava_mixtral"
48
+ config.rope_scaling = None
49
+ self.model = LlavaMixtralModel(config)
50
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
51
+ # Initialize weights and apply final processing
52
+ self.post_init()
53
+
54
+ def get_model(self):
55
+ return self.model
56
+
57
+ def forward(
58
+ self,
59
+ input_ids: torch.LongTensor = None,
60
+ attention_mask: Optional[torch.Tensor] = None,
61
+ position_ids: Optional[torch.LongTensor] = None,
62
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
63
+ inputs_embeds: Optional[torch.FloatTensor] = None,
64
+ labels: Optional[torch.LongTensor] = None,
65
+ use_cache: Optional[bool] = None,
66
+ output_attentions: Optional[bool] = None,
67
+ output_hidden_states: Optional[bool] = None,
68
+ images: Optional[torch.FloatTensor] = None,
69
+ image_sizes: Optional[List[List[int]]] = None,
70
+ return_dict: Optional[bool] = None,
71
+ cache_position=None,
72
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
73
+
74
+ if inputs_embeds is None:
75
+ (input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels) = self.prepare_inputs_labels_for_multimodal(input_ids, position_ids, attention_mask, past_key_values, labels, images, image_sizes)
76
+
77
+ return super().forward(
78
+ input_ids=input_ids,
79
+ attention_mask=attention_mask,
80
+ position_ids=position_ids,
81
+ past_key_values=past_key_values,
82
+ inputs_embeds=inputs_embeds,
83
+ labels=labels,
84
+ use_cache=use_cache,
85
+ output_attentions=output_attentions,
86
+ output_hidden_states=output_hidden_states,
87
+ return_dict=return_dict,
88
+ )
89
+
90
+ @torch.no_grad()
91
+ def generate(
92
+ self,
93
+ inputs: Optional[torch.Tensor] = None,
94
+ images: Optional[torch.Tensor] = None,
95
+ image_sizes: Optional[torch.Tensor] = None,
96
+ **kwargs,
97
+ ) -> Union[GenerateOutput, torch.LongTensor]:
98
+ position_ids = kwargs.pop("position_ids", None)
99
+ attention_mask = kwargs.pop("attention_mask", None)
100
+ if "inputs_embeds" in kwargs:
101
+ raise NotImplementedError("`inputs_embeds` is not supported")
102
+
103
+ if images is not None:
104
+ (inputs, position_ids, attention_mask, _, inputs_embeds, _) = self.prepare_inputs_labels_for_multimodal(inputs, position_ids, attention_mask, None, None, images, image_sizes=image_sizes)
105
+ else:
106
+ inputs_embeds = self.get_model().embed_tokens(inputs)
107
+
108
+ return super().generate(position_ids=position_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, **kwargs)
109
+
110
+ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
111
+ images = kwargs.pop("images", None)
112
+ image_sizes = kwargs.pop("image_sizes", None)
113
+ inputs = super().prepare_inputs_for_generation(input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs)
114
+ if images is not None:
115
+ inputs["images"] = images
116
+ if image_sizes is not None:
117
+ inputs["image_sizes"] = image_sizes
118
+ return inputs
119
+
120
+
121
+ AutoConfig.register("llava_mixtral", LlavaMixtralConfig)
122
+ AutoModelForCausalLM.register(LlavaMixtralConfig, LlavaMixtralForCausalLM)
llava/model/language_model/llava_mpt.py ADDED
@@ -0,0 +1,105 @@
1
+ # Copyright 2023 Haotian Liu
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from typing import Optional, Tuple
17
+
18
+ import torch
19
+
20
+ from transformers import AutoConfig, AutoModelForCausalLM, MptConfig, MptForCausalLM, MptModel, GenerationConfig
21
+ from llava.model.llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
22
+
23
+
24
+ class LlavaMptConfig(MptConfig):
25
+ model_type = "llava_mpt"
26
+
27
+
28
+ class LlavaMptModel(LlavaMetaModel, MptModel):
29
+ config_class = LlavaMptConfig
30
+
31
+ def __init__(self, config: MptConfig):
32
+ config.hidden_size = config.d_model
33
+ super(LlavaMptModel, self).__init__(config)
34
+
35
+ def embed_tokens(self, x):
36
+ return self.wte(x)
37
+
38
+
39
+ class LlavaMptForCausalLM(MptForCausalLM, LlavaMetaForCausalLM):
40
+ config_class = LlavaMptConfig
41
+ supports_gradient_checkpointing = True
42
+
43
+ def __init__(self, config):
44
+ super(MptForCausalLM, self).__init__(config)
45
+
46
+ config.model_type = "llava_mpt"
47
+ config.rope_scaling = None
48
+ self.generation_config = GenerationConfig(
49
+ temperature=0.0,
50
+ max_new_tokens=1024,
51
+ do_sample=False,
52
+ top_p=None,
53
+ )
54
+
55
+ self.transformer = LlavaMptModel(config)
56
+ self.lm_head = torch.nn.Linear(config.hidden_size, config.vocab_size, bias=False)
57
+
58
+ # Initialize weights and apply final processing
59
+ self.post_init()
60
+
61
+ def get_model(self):
62
+ return self.transformer
63
+
64
+ def _set_gradient_checkpointing(self, module, value=False):
65
+ if isinstance(module, LlavaMptModel):
66
+ module.gradient_checkpointing = value
67
+
68
+ def forward(
69
+ self,
70
+ input_ids: Optional[torch.LongTensor] = None,
71
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
72
+ attention_mask: Optional[torch.Tensor] = None,
73
+ inputs_embeds: Optional[torch.Tensor] = None,
74
+ labels: Optional[torch.Tensor] = None,
75
+ use_cache: Optional[bool] = None,
76
+ output_attentions: Optional[bool] = None,
77
+ output_hidden_states: Optional[bool] = None,
78
+ return_dict: Optional[bool] = None,
79
+ cache_position=None,
80
+ images=None,
81
+ ):
82
+
83
+ input_ids, attention_mask, past_key_values, inputs_embeds, labels = self.prepare_inputs_labels_for_multimodal(input_ids, attention_mask, past_key_values, labels, images)
84
+
85
+ return super().forward(
86
+ input_ids,
87
+ past_key_values=past_key_values,
88
+ attention_mask=attention_mask,
89
+ inputs_embeds=inputs_embeds,
90
+ labels=labels,
91
+ use_cache=use_cache,
92
+ output_attentions=output_attentions,
93
+ output_hidden_states=output_hidden_states,
94
+ return_dict=return_dict,
95
+ )
96
+
97
+ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
98
+ images = kwargs.pop("images", None)
99
+ _inputs = super().prepare_inputs_for_generation(input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs)
100
+ _inputs["images"] = images
101
+ return _inputs
102
+
103
+
104
+ AutoConfig.register("llava_mpt", LlavaMptConfig)
105
+ AutoModelForCausalLM.register(LlavaMptConfig, LlavaMptForCausalLM)
llava/model/language_model/llava_qwen.py ADDED
@@ -0,0 +1,128 @@
1
+ # Copyright 2024 Hao Zhang
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from typing import List, Optional, Tuple, Union, Dict
17
+ import torch
18
+ import torch.nn as nn
19
+ from torch.nn import CrossEntropyLoss
20
+
21
+ import transformers
22
+ from transformers import AutoConfig, AutoModelForCausalLM, LlamaConfig, LlamaModel, LlamaForCausalLM
23
+
24
+ from transformers.modeling_outputs import CausalLMOutputWithPast
25
+ from transformers.generation.utils import GenerateOutput
26
+
27
+ # from ...constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
28
+ from llava.model.llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
29
+ from transformers import Qwen2Config, Qwen2Model, Qwen2ForCausalLM
30
+
31
+ # from .qwen.modeling_qwen import QWenLMHeadModel, QWenModel
32
+ # from .qwen.configuration_qwen import QWenConfig
33
+
34
+
35
+ class LlavaQwenConfig(Qwen2Config):
36
+ model_type = "llava_qwen"
37
+
38
+
39
+ class LlavaQwenModel(LlavaMetaModel, Qwen2Model):
40
+ config_class = LlavaQwenConfig
41
+
42
+ def __init__(self, config: Qwen2Config):
43
+ super(LlavaQwenModel, self).__init__(config)
44
+
45
+
46
+ class LlavaQwenForCausalLM(Qwen2ForCausalLM, LlavaMetaForCausalLM):
47
+ config_class = LlavaQwenConfig
48
+
49
+ def __init__(self, config):
50
+ # super(Qwen2ForCausalLM, self).__init__(config)
51
+ Qwen2ForCausalLM.__init__(self, config)
52
+ config.model_type = "llava_qwen"
53
+ config.rope_scaling = None
54
+
55
+ self.model = LlavaQwenModel(config)
56
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
57
+ # Initialize weights and apply final processing
58
+ self.post_init()
59
+
60
+ def get_model(self):
61
+ return self.model
62
+
63
+ def forward(
64
+ self,
65
+ input_ids: torch.LongTensor = None,
66
+ attention_mask: Optional[torch.Tensor] = None,
67
+ position_ids: Optional[torch.LongTensor] = None,
68
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
69
+ inputs_embeds: Optional[torch.FloatTensor] = None,
70
+ labels: Optional[torch.LongTensor] = None,
71
+ use_cache: Optional[bool] = None,
72
+ output_attentions: Optional[bool] = None,
73
+ output_hidden_states: Optional[bool] = None,
74
+ images: Optional[torch.FloatTensor] = None,
75
+ image_sizes: Optional[List[List[int]]] = None,
76
+ return_dict: Optional[bool] = None,
77
+ cache_position=None,
78
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
79
+
80
+ if inputs_embeds is None:
81
+ (input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels) = self.prepare_inputs_labels_for_multimodal(input_ids, position_ids, attention_mask, past_key_values, labels, images, image_sizes)
82
+
83
+ return super().forward(
84
+ input_ids=input_ids,
85
+ attention_mask=attention_mask,
86
+ position_ids=position_ids,
87
+ past_key_values=past_key_values,
88
+ inputs_embeds=inputs_embeds,
89
+ labels=labels,
90
+ use_cache=use_cache,
91
+ output_attentions=output_attentions,
92
+ output_hidden_states=output_hidden_states,
93
+ return_dict=return_dict,
94
+ )
95
+
96
+ @torch.no_grad()
97
+ def generate(
98
+ self,
99
+ inputs: Optional[torch.Tensor] = None,
100
+ images: Optional[torch.Tensor] = None,
101
+ image_sizes: Optional[torch.Tensor] = None,
102
+ **kwargs,
103
+ ) -> Union[GenerateOutput, torch.LongTensor]:
104
+ position_ids = kwargs.pop("position_ids", None)
105
+ attention_mask = kwargs.pop("attention_mask", None)
106
+ if "inputs_embeds" in kwargs:
107
+ raise NotImplementedError("`inputs_embeds` is not supported")
108
+
109
+ if images is not None:
110
+ (inputs, position_ids, attention_mask, _, inputs_embeds, _) = self.prepare_inputs_labels_for_multimodal(inputs, position_ids, attention_mask, None, None, images, image_sizes=image_sizes)
111
+ else:
112
+ inputs_embeds = self.get_model().embed_tokens(inputs)
113
+
114
+ return super().generate(position_ids=position_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, **kwargs)
115
+
116
+ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
117
+ images = kwargs.pop("images", None)
118
+ image_sizes = kwargs.pop("image_sizes", None)
119
+ inputs = super().prepare_inputs_for_generation(input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs)
120
+ if images is not None:
121
+ inputs["images"] = images
122
+ if image_sizes is not None:
123
+ inputs["image_sizes"] = image_sizes
124
+ return inputs
125
+
126
+
127
+ AutoConfig.register("llava_qwen", LlavaQwenConfig)
128
+ AutoModelForCausalLM.register(LlavaQwenConfig, LlavaQwenForCausalLM)
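The overridden `generate` above splices the encoded image features into `inputs_embeds` via `prepare_inputs_labels_for_multimodal` before delegating to the Hugging Face generation loop, so callers pass `images` and `image_sizes` alongside the tokenized prompt. A hedged usage sketch, assuming `model`, `input_ids`, `image_tensor` and the PIL image `original_image` were prepared elsewhere (all of these names are illustrative):

    import torch

    # input_ids must already contain the image placeholder token (IMAGE_TOKEN_INDEX)
    # wherever an image's features should be spliced in.
    with torch.inference_mode():
        output_ids = model.generate(
            inputs=input_ids,                   # (1, seq_len) LongTensor on the model's device
            images=image_tensor,                # pixels preprocessed by the vision tower's image processor
            image_sizes=[original_image.size],  # original (width, height) per image, used for anyres unpadding
            do_sample=False,
            max_new_tokens=512,
        )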
llava/model/language_model/llava_qwen_moe.py ADDED
@@ -0,0 +1,128 @@
1
+ # Copyright 2024 Hao Zhang
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from typing import List, Optional, Tuple, Union, Dict
17
+ import torch
18
+ import torch.nn as nn
19
+ from torch.nn import CrossEntropyLoss
20
+
21
+ import transformers
22
+ from transformers import AutoConfig, AutoModelForCausalLM
23
+
24
+ from transformers.modeling_outputs import CausalLMOutputWithPast
25
+ from transformers.generation.utils import GenerateOutput
26
+
27
+ # from ...constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
28
+ from llava.model.llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
29
+ from transformers import Qwen2MoeConfig, Qwen2MoeModel, Qwen2MoeForCausalLM
30
+
31
+ # from .qwen.modeling_qwen import QWenLMHeadModel, QWenModel
32
+ # from .qwen.configuration_qwen import QWenConfig
33
+
34
+
35
+ class LlavaQwenMoeConfig(Qwen2MoeConfig):
36
+ model_type = "llava_qwen_moe"
37
+
38
+
39
+ class LlavaQwenMoeModel(LlavaMetaModel, Qwen2MoeModel):
40
+ config_class = LlavaQwenMoeConfig
41
+
42
+ def __init__(self, config: Qwen2MoeConfig):
43
+ super(LlavaQwenMoeModel, self).__init__(config)
44
+
45
+
46
+ class LlavaQwenMoeForCausalLM(Qwen2MoeForCausalLM, LlavaMetaForCausalLM):
47
+ config_class = LlavaQwenMoeConfig
48
+
49
+ def __init__(self, config):
50
+ # super(Qwen2MoeForCausalLM, self).__init__(config)
51
+ Qwen2MoeForCausalLM.__init__(self, config)
52
+ config.model_type = "llava_qwen_moe"
53
+ config.rope_scaling = None
54
+
55
+ self.model = LlavaQwenMoeModel(config)
56
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
57
+ # Initialize weights and apply final processing
58
+ self.post_init()
59
+
60
+ def get_model(self):
61
+ return self.model
62
+
63
+ def forward(
64
+ self,
65
+ input_ids: torch.LongTensor = None,
66
+ attention_mask: Optional[torch.Tensor] = None,
67
+ position_ids: Optional[torch.LongTensor] = None,
68
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
69
+ inputs_embeds: Optional[torch.FloatTensor] = None,
70
+ labels: Optional[torch.LongTensor] = None,
71
+ use_cache: Optional[bool] = None,
72
+ output_attentions: Optional[bool] = None,
73
+ output_hidden_states: Optional[bool] = None,
74
+ images: Optional[torch.FloatTensor] = None,
75
+ image_sizes: Optional[List[List[int]]] = None,
76
+ return_dict: Optional[bool] = None,
77
+ cache_position=None,
78
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
79
+
80
+ if inputs_embeds is None:
81
+ (input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels) = self.prepare_inputs_labels_for_multimodal(input_ids, position_ids, attention_mask, past_key_values, labels, images, image_sizes)
82
+
83
+ return super().forward(
84
+ input_ids=input_ids,
85
+ attention_mask=attention_mask,
86
+ position_ids=position_ids,
87
+ past_key_values=past_key_values,
88
+ inputs_embeds=inputs_embeds,
89
+ labels=labels,
90
+ use_cache=use_cache,
91
+ output_attentions=output_attentions,
92
+ output_hidden_states=output_hidden_states,
93
+ return_dict=return_dict,
94
+ )
95
+
96
+ @torch.no_grad()
97
+ def generate(
98
+ self,
99
+ inputs: Optional[torch.Tensor] = None,
100
+ images: Optional[torch.Tensor] = None,
101
+ image_sizes: Optional[torch.Tensor] = None,
102
+ **kwargs,
103
+ ) -> Union[GenerateOutput, torch.LongTensor]:
104
+ position_ids = kwargs.pop("position_ids", None)
105
+ attention_mask = kwargs.pop("attention_mask", None)
106
+ if "inputs_embeds" in kwargs:
107
+ raise NotImplementedError("`inputs_embeds` is not supported")
108
+
109
+ if images is not None:
110
+ (inputs, position_ids, attention_mask, _, inputs_embeds, _) = self.prepare_inputs_labels_for_multimodal(inputs, position_ids, attention_mask, None, None, images, image_sizes=image_sizes)
111
+ else:
112
+ inputs_embeds = self.get_model().embed_tokens(inputs)
113
+
114
+ return super().generate(position_ids=position_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, **kwargs)
115
+
116
+ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
117
+ images = kwargs.pop("images", None)
118
+ image_sizes = kwargs.pop("image_sizes", None)
119
+ inputs = super().prepare_inputs_for_generation(input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs)
120
+ if images is not None:
121
+ inputs["images"] = images
122
+ if image_sizes is not None:
123
+ inputs["image_sizes"] = image_sizes
124
+ return inputs
125
+
126
+
127
+ AutoConfig.register("llava_qwen_moe", LlavaQwenMoeConfig)
128
+ AutoModelForCausalLM.register(LlavaQwenMoeConfig, LlavaQwenMoeForCausalLM)
llava/model/llava_arch.py ADDED
@@ -0,0 +1,389 @@
1
+ # Copyright 2023 Haotian Liu
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from abc import ABC, abstractmethod
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+
21
+ from .multimodal_encoder.builder import build_vision_tower
22
+ from .multimodal_resampler.builder import build_vision_resampler
23
+ from .multimodal_projector.builder import build_vision_projector
24
+
25
+ from llava.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
26
+
27
+ from llava.mm_utils import get_anyres_image_grid_shape
28
+ from llava.utils import rank0_print
29
+
30
+
31
+ class LlavaMetaModel:
32
+
33
+ def __init__(self, config):
34
+ super(LlavaMetaModel, self).__init__(config)
35
+
36
+ if hasattr(config, "mm_vision_tower"):
37
+ delay_load = getattr(config, "delay_load", False)
38
+ self.vision_tower = build_vision_tower(config, delay_load=delay_load)
39
+ self.vision_resampler = build_vision_resampler(config, vision_tower=self.vision_tower)
40
+ self.mm_projector = build_vision_projector(config, vision_cfg=self.vision_tower.config)
41
+
42
+ if "unpad" in getattr(config, "mm_patch_merge_type", ""):
43
+ self.image_newline = nn.Parameter(torch.empty(config.hidden_size, dtype=self.dtype))
44
+
45
+ def get_vision_tower(self):
46
+ vision_tower = getattr(self, "vision_tower", None)
47
+ if type(vision_tower) is list:
48
+ vision_tower = vision_tower[0]
49
+ return vision_tower
50
+
51
+ def initialize_vision_modules(self, model_args, fsdp=None):
52
+ vision_tower = model_args.vision_tower
53
+ mm_vision_select_layer = model_args.mm_vision_select_layer
54
+ mm_vision_select_feature = model_args.mm_vision_select_feature
55
+ pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter
56
+ mm_patch_merge_type = model_args.mm_patch_merge_type
57
+
58
+ self.config.mm_vision_tower = vision_tower
59
+ self.config.vision_tower_pretrained = getattr(model_args, "vision_tower_pretrained", "")
60
+
61
+ if self.get_vision_tower() is None:
62
+ vision_tower = build_vision_tower(model_args)
63
+ vision_resampler = build_vision_resampler(model_args, vision_tower=vision_tower)
64
+ for k, v in vision_resampler.config.items():
65
+ setattr(self.config, k, v)
66
+
67
+ if fsdp is not None and len(fsdp) > 0:
68
+ self.vision_tower = [vision_tower]
69
+ self.vision_resampler = [vision_resampler]
70
+ else:
71
+ self.vision_tower = vision_tower
72
+ self.vision_resampler = vision_resampler
73
+ else:
74
+ if fsdp is not None and len(fsdp) > 0:
75
+ vision_resampler = self.vision_resampler[0]
76
+ vision_tower = self.vision_tower[0]
77
+ else:
78
+ vision_resampler = self.vision_resampler
79
+ vision_tower = self.vision_tower
80
+ vision_tower.load_model()
81
+
82
+ # In case it is frozen by LoRA
83
+ for p in self.vision_resampler.parameters():
84
+ p.requires_grad = True
85
+
86
+ self.config.use_mm_proj = True
87
+ self.config.mm_projector_type = getattr(model_args, "mm_projector_type", "linear")
88
+ self.config.mm_hidden_size = getattr(vision_resampler, "hidden_size", vision_tower.hidden_size)
89
+ self.config.mm_vision_select_layer = mm_vision_select_layer
90
+ self.config.mm_vision_select_feature = mm_vision_select_feature
91
+ self.config.mm_patch_merge_type = mm_patch_merge_type
92
+
93
+ if getattr(self, "mm_projector", None) is None:
94
+ self.mm_projector = build_vision_projector(self.config, vision_cfg=vision_tower.config)
95
+
96
+ if "unpad" in mm_patch_merge_type:
97
+ embed_std = 1 / torch.sqrt(torch.tensor(self.config.hidden_size, dtype=self.dtype))
98
+ self.image_newline = nn.Parameter(torch.randn(self.config.hidden_size, dtype=self.dtype) * embed_std)
99
+ else:
100
+ # In case it is frozen by LoRA
101
+ for p in self.mm_projector.parameters():
102
+ p.requires_grad = True
103
+
104
+ if pretrain_mm_mlp_adapter is not None:
105
+ mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location="cpu")
106
+
107
+ def get_w(weights, keyword):
108
+ return {k.split(keyword + ".")[1]: v for k, v in weights.items() if keyword in k}
109
+
110
+ incompatible_keys = self.mm_projector.load_state_dict(get_w(mm_projector_weights, "mm_projector"))
111
+ rank0_print(f"Loaded mm projector weights from {pretrain_mm_mlp_adapter}. Incompatible keys: {incompatible_keys}")
112
+ incompatible_keys = self.vision_resampler.load_state_dict(get_w(mm_projector_weights, "vision_resampler"), strict=False)
113
+ rank0_print(f"Loaded vision resampler weights from {pretrain_mm_mlp_adapter}. Incompatible keys: {incompatible_keys}")
114
+
115
+
116
+ def unpad_image(tensor, original_size):
117
+ """
118
+ Unpads a PyTorch tensor of a padded and resized image.
119
+
120
+ Args:
121
+ tensor (torch.Tensor): The image tensor, assumed to be in CxHxW format.
122
+ original_size (tuple): The original size of the image (height, width).
123
+
124
+ Returns:
125
+ torch.Tensor: The unpadded image tensor.
126
+ """
127
+ original_width, original_height = original_size
128
+ current_height, current_width = tensor.shape[1:]
129
+
130
+ # Compute aspect ratios
131
+ original_aspect_ratio = original_width / original_height
132
+ current_aspect_ratio = current_width / current_height
133
+
134
+ # Determine padding size and direction
135
+ if original_aspect_ratio > current_aspect_ratio:
136
+ # Padding was added to the height
137
+ scale_factor = current_width / original_width
138
+ new_height = int(original_height * scale_factor)
139
+ padding = (current_height - new_height) // 2
140
+ unpadded_tensor = tensor[:, padding : current_height - padding, :]
141
+ else:
142
+ # Padding was added to the width
143
+ scale_factor = current_height / original_height
144
+ new_width = int(original_width * scale_factor)
145
+ padding = (current_width - new_width) // 2
146
+ unpadded_tensor = tensor[:, :, padding : current_width - padding]
147
+
148
+ return unpadded_tensor
149
+
150
+
151
+ class LlavaMetaForCausalLM(ABC):
152
+
153
+ @abstractmethod
154
+ def get_model(self):
155
+ pass
156
+
157
+ def get_vision_tower(self):
158
+ return self.get_model().get_vision_tower()
159
+
160
+ def encode_images(self, images):
161
+ image_features = self.get_model().get_vision_tower()(images)
162
+ image_features = self.get_model().vision_resampler(image_features, images=images)
163
+ image_features = self.get_model().mm_projector(image_features)
164
+ return image_features
165
+
166
+ def prepare_inputs_labels_for_multimodal(self, input_ids, position_ids, attention_mask, past_key_values, labels, images, image_sizes=None):
167
+ vision_tower = self.get_vision_tower()
168
+ if vision_tower is None or images is None or input_ids.shape[1] == 1:
169
+ return input_ids, position_ids, attention_mask, past_key_values, None, labels
170
+
171
+ if type(images) is list or images.ndim == 5:
172
+ if type(images) is list:
173
+ images = [x.unsqueeze(0) if x.ndim == 3 else x for x in images]
174
+ concat_images = torch.cat([image for image in images], dim=0)
175
+ image_features = self.encode_images(concat_images)
176
+ split_sizes = [image.shape[0] for image in images]
177
+ image_features = torch.split(image_features, split_sizes, dim=0)
178
+ mm_patch_merge_type = getattr(self.config, "mm_patch_merge_type", "flat")
179
+ image_aspect_ratio = getattr(self.config, "image_aspect_ratio", "square")
180
+ if mm_patch_merge_type == "flat":
181
+ image_features = [x.flatten(0, 1) for x in image_features]
182
+ elif mm_patch_merge_type.startswith("spatial"):
183
+ new_image_features = []
184
+ for image_idx, image_feature in enumerate(image_features):
185
+ # FIXME: now assume the image is square, and split to 2x2 patches
186
+ # num_patches = h * w, where h = w = sqrt(num_patches)
187
+ # currently image_feature is a tensor of shape (4, num_patches, hidden_size)
188
+ # we want to first unflatten it to (2, 2, h, w, hidden_size)
189
+
190
+ if image_feature.shape[0] > 1:
191
+ base_image_feature = image_feature[0]
192
+ image_feature = image_feature[1:]
193
+ height = width = self.get_vision_tower().num_patches_per_side
194
+ assert height * width == base_image_feature.shape[0]
195
+ if image_aspect_ratio == "anyres":
196
+ if hasattr(self.get_vision_tower(), "image_size"):
197
+ vision_tower_image_size = self.get_vision_tower().image_size
198
+ else:
199
+ raise ValueError("vision_tower_image_size is not found in the vision tower.")
200
+ num_patch_width, num_patch_height = get_anyres_image_grid_shape(image_sizes[image_idx], self.config.image_grid_pinpoints, vision_tower_image_size)
201
+ image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
202
+ else:
203
+ image_feature = image_feature.view(2, 2, height, width, -1)
204
+ if "maxpool2x2" in mm_patch_merge_type:
205
+ image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
206
+ image_feature = image_feature.flatten(1, 2).flatten(2, 3)
207
+ image_feature = nn.functional.max_pool2d(image_feature, 2)
208
+ image_feature = image_feature.flatten(1, 2).transpose(0, 1)
209
+ elif "unpad" in mm_patch_merge_type:
210
+ image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
211
+ image_feature = image_feature.flatten(1, 2).flatten(2, 3)
212
+ image_feature = unpad_image(image_feature, image_sizes[image_idx])
213
+ image_feature = torch.cat((image_feature, self.model.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.device)), dim=-1)
214
+ image_feature = image_feature.flatten(1, 2).transpose(0, 1)
215
+ else:
216
+ image_feature = image_feature.permute(0, 2, 1, 3, 4).contiguous()
217
+ image_feature = image_feature.flatten(0, 3)
218
+ if "nobase" in mm_patch_merge_type:
219
+ pass
220
+ else:
221
+ image_feature = torch.cat((base_image_feature, image_feature), dim=0)
222
+ else:
223
+ image_feature = image_feature[0]
224
+ if "unpad" in mm_patch_merge_type:
225
+ image_feature = torch.cat((image_feature, self.model.image_newline[None]), dim=0)
226
+ new_image_features.append(image_feature)
227
+ image_features = new_image_features
228
+ else:
229
+ raise ValueError(f"Unexpected mm_patch_merge_type: {self.config.mm_patch_merge_type}")
230
+ else:
231
+ image_features = self.encode_images(images)
232
+
233
+ # TODO: image start / end is not implemented here to support pretraining.
234
+ if getattr(self.config, "tune_mm_mlp_adapter", False) and getattr(self.config, "mm_use_im_start_end", False):
235
+ raise NotImplementedError
236
+
237
+ # Let's just add dummy tensors if they do not exist,
238
+ # it is a headache to deal with None all the time.
239
+ # But it is not ideal, and if you have a better idea,
240
+ # please open an issue / submit a PR, thanks.
241
+ _labels = labels
242
+ _position_ids = position_ids
243
+ _attention_mask = attention_mask
244
+ if attention_mask is None:
245
+ attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
246
+ else:
247
+ attention_mask = attention_mask.bool()
248
+ if position_ids is None:
249
+ position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
250
+ if labels is None:
251
+ labels = torch.full_like(input_ids, IGNORE_INDEX)
252
+
253
+ # remove the padding using attention_mask -- FIXME
254
+ _input_ids = input_ids
255
+ input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)]
256
+ labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)]
257
+
258
+ new_input_embeds = []
259
+ new_labels = []
260
+ cur_image_idx = 0
261
+ for batch_idx, cur_input_ids in enumerate(input_ids):
262
+ num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
263
+ if num_images == 0:
264
+ cur_image_features = image_features[cur_image_idx]
265
+ cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids)
266
+ cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0]], dim=0)
267
+ new_input_embeds.append(cur_input_embeds)
268
+ new_labels.append(labels[batch_idx])
269
+ cur_image_idx += 1
270
+ continue
271
+
272
+ image_token_indices = [-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]]
273
+ cur_input_ids_noim = []
274
+ cur_labels = labels[batch_idx]
275
+ cur_labels_noim = []
276
+ for i in range(len(image_token_indices) - 1):
277
+ cur_input_ids_noim.append(cur_input_ids[image_token_indices[i] + 1 : image_token_indices[i + 1]])
278
+ cur_labels_noim.append(cur_labels[image_token_indices[i] + 1 : image_token_indices[i + 1]])
279
+ split_sizes = [x.shape[0] for x in cur_labels_noim]
280
+ cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_noim))
281
+ cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0)
282
+ cur_new_input_embeds = []
283
+ cur_new_labels = []
284
+
285
+ for i in range(num_images + 1):
286
+ cur_new_input_embeds.append(cur_input_embeds_no_im[i])
287
+ cur_new_labels.append(cur_labels_noim[i])
288
+ if i < num_images:
289
+ cur_image_features = image_features[cur_image_idx]
290
+ cur_image_idx += 1
291
+ cur_new_input_embeds.append(cur_image_features)
292
+ cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype))
293
+
294
+ cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds]
295
+
296
+ cur_new_input_embeds = torch.cat(cur_new_input_embeds)
297
+ cur_new_labels = torch.cat(cur_new_labels)
298
+
299
+ new_input_embeds.append(cur_new_input_embeds)
300
+ new_labels.append(cur_new_labels)
301
+
302
+ # Truncate sequences to max length as image embeddings can make the sequence longer
303
+ tokenizer_model_max_length = getattr(self.config, "tokenizer_model_max_length", None)
304
+ if tokenizer_model_max_length is not None:
305
+ new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds]
306
+ new_labels = [x[:tokenizer_model_max_length] for x in new_labels]
307
+
308
+ # Combine them
309
+ max_len = max(x.shape[0] for x in new_input_embeds)
310
+ batch_size = len(new_input_embeds)
311
+
312
+ new_input_embeds_padded = []
313
+ new_labels_padded = torch.full((batch_size, max_len), IGNORE_INDEX, dtype=new_labels[0].dtype, device=new_labels[0].device)
314
+ attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device)
315
+ position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device)
316
+
317
+ for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)):
318
+ cur_len = cur_new_embed.shape[0]
319
+ if getattr(self.config, "tokenizer_padding_side", "right") == "left":
320
+ new_input_embeds_padded.append(torch.cat((torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device), cur_new_embed), dim=0))
321
+ if cur_len > 0:
322
+ new_labels_padded[i, -cur_len:] = cur_new_labels
323
+ attention_mask[i, -cur_len:] = True
324
+ position_ids[i, -cur_len:] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
325
+ else:
326
+ new_input_embeds_padded.append(torch.cat((cur_new_embed, torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)), dim=0))
327
+ if cur_len > 0:
328
+ new_labels_padded[i, :cur_len] = cur_new_labels
329
+ attention_mask[i, :cur_len] = True
330
+ position_ids[i, :cur_len] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
331
+
332
+ new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)
333
+
334
+ if _labels is None:
335
+ new_labels = None
336
+ else:
337
+ new_labels = new_labels_padded
338
+
339
+ if _attention_mask is None:
340
+ attention_mask = None
341
+ else:
342
+ attention_mask = attention_mask.to(dtype=_attention_mask.dtype)
343
+
344
+ if _position_ids is None:
345
+ position_ids = None
346
+
347
+ return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels
348
+
349
+ def initialize_vision_tokenizer(self, model_args, tokenizer):
350
+ if model_args.mm_use_im_patch_token:
351
+ tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
352
+ self.resize_token_embeddings(len(tokenizer))
353
+
354
+ if model_args.mm_use_im_start_end:
355
+ num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
356
+ self.resize_token_embeddings(len(tokenizer))
357
+
358
+ if num_new_tokens > 0:
359
+ input_embeddings = self.get_input_embeddings().weight.data
360
+ output_embeddings = self.get_output_embeddings().weight.data
361
+
362
+ input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
363
+ output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
364
+
365
+ input_embeddings[-num_new_tokens:] = input_embeddings_avg
366
+ output_embeddings[-num_new_tokens:] = output_embeddings_avg
367
+
368
+ if model_args.tune_mm_mlp_adapter:
369
+ for p in self.get_input_embeddings().parameters():
370
+ p.requires_grad = True
371
+ for p in self.get_output_embeddings().parameters():
372
+ p.requires_grad = False
373
+
374
+ if model_args.pretrain_mm_mlp_adapter:
375
+ mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location="cpu")
376
+ embed_tokens_weight = mm_projector_weights["model.embed_tokens.weight"]
377
+ assert num_new_tokens == 2
378
+ if input_embeddings.shape == embed_tokens_weight.shape:
379
+ input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:]
380
+ elif embed_tokens_weight.shape[0] == num_new_tokens:
381
+ input_embeddings[-num_new_tokens:] = embed_tokens_weight
382
+ else:
383
+ raise ValueError(f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Number of new tokens: {num_new_tokens}.")
384
+ elif model_args.mm_use_im_patch_token:
385
+ if model_args.tune_mm_mlp_adapter:
386
+ for p in self.get_input_embeddings().parameters():
387
+ p.requires_grad = False
388
+ for p in self.get_output_embeddings().parameters():
389
+ p.requires_grad = False
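`unpad_image` removes the letterboxing that was introduced when a non-square image was padded into the square vision-tower resolution, using the original `(width, height)` to decide which axis carries the padding. A small self-contained check (shapes chosen purely for illustration):

    import torch
    from llava.model.llava_arch import unpad_image

    # A 24x24 feature grid that came from an originally 2:1 (width:height) image:
    # the content fills the middle 12 rows, with 6 rows of padding above and below.
    feat = torch.zeros(3, 24, 24)                      # C x H x W
    out = unpad_image(feat, original_size=(200, 100))  # (width, height)
    print(out.shape)                                   # torch.Size([3, 12, 24])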
llava/model/make_delta.py ADDED
@@ -0,0 +1,52 @@
1
+ """
2
+ Usage:
3
+ python3 -m llava.model.make_delta --base ~/model_weights/llama-7b --target ~/model_weights/llava-7b --delta ~/model_weights/llava-7b-delta --hub-repo-id liuhaotian/llava-7b-delta
4
+ """
5
+
6
+ import argparse
7
+
8
+ import torch
9
+ from tqdm import tqdm
10
+ from transformers import AutoTokenizer, AutoModelForCausalLM
11
+ from llava.model.utils import auto_upgrade
12
+
13
+
14
+ def make_delta(base_model_path, target_model_path, delta_path, hub_repo_id):
15
+ print("Loading base model")
16
+ base = AutoModelForCausalLM.from_pretrained(base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
17
+
18
+ print("Loading target model")
19
+ auto_upgrade(target_model_path)
20
+ target = AutoModelForCausalLM.from_pretrained(target_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
21
+
22
+ print("Calculating delta")
23
+ for name, param in tqdm(target.state_dict().items(), desc="Calculating delta"):
24
+ if name not in base.state_dict():
25
+ assert name in ["model.mm_projector.weight", "model.mm_projector.bias"], f"{name} not in base model"
26
+ continue
27
+ if param.data.shape == base.state_dict()[name].shape:
28
+ param.data -= base.state_dict()[name]
29
+ else:
30
+ assert name in ["model.embed_tokens.weight", "lm_head.weight"], f"{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}"
31
+ bparam = base.state_dict()[name]
32
+ param.data[: bparam.shape[0], : bparam.shape[1]] -= bparam
33
+
34
+ print("Saving delta")
35
+ if hub_repo_id:
36
+ kwargs = {"push_to_hub": True, "repo_id": hub_repo_id}
37
+ else:
38
+ kwargs = {}
39
+ target.save_pretrained(delta_path, **kwargs)
40
+ target_tokenizer = AutoTokenizer.from_pretrained(target_model_path)
41
+ target_tokenizer.save_pretrained(delta_path, **kwargs)
42
+
43
+
44
+ if __name__ == "__main__":
45
+ parser = argparse.ArgumentParser()
46
+ parser.add_argument("--base-model-path", type=str, required=True)
47
+ parser.add_argument("--target-model-path", type=str, required=True)
48
+ parser.add_argument("--delta-path", type=str, required=True)
49
+ parser.add_argument("--hub-repo-id", type=str, default=None)
50
+ args = parser.parse_args()
51
+
52
+ make_delta(args.base_model_path, args.target_model_path, args.delta_path, args.hub_repo_id)
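The delta is the element-wise difference `target - base` for every shared tensor; projector weights are kept verbatim, and for the resized `embed_tokens` / `lm_head` matrices only the overlapping slice is subtracted. Reconstructing the target is therefore the mirror-image addition. A minimal sketch of that inverse, assuming `base` and `delta` were loaded with `AutoModelForCausalLM.from_pretrained` as above (not a drop-in script):

    # Sketch of the inverse of make_delta: add the base weights back onto the delta.
    for name, param in delta.state_dict().items():
        if name not in base.state_dict():
            continue  # e.g. model.mm_projector.* exists only in the delta/target
        bparam = base.state_dict()[name]
        if param.data.shape == bparam.shape:
            param.data += bparam
        else:
            # embed_tokens / lm_head gained extra rows; only the overlapping
            # slice was subtracted, so only that slice is restored here.
            param.data[: bparam.shape[0], : bparam.shape[1]] += bparam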
llava/model/multimodal_encoder/__pycache__/builder.cpython-310.pyc ADDED
Binary file (772 Bytes).
 
llava/model/multimodal_encoder/__pycache__/clip_encoder.cpython-310.pyc ADDED
Binary file (4.4 kB).
 
llava/model/multimodal_encoder/__pycache__/siglip_encoder.cpython-310.pyc ADDED
Binary file (21.8 kB).
 
llava/model/multimodal_encoder/builder.py ADDED
@@ -0,0 +1,14 @@
1
+ import os
2
+ from .clip_encoder import CLIPVisionTower
3
+ from .siglip_encoder import SigLipVisionTower
4
+
5
+
6
+ def build_vision_tower(vision_tower_cfg, **kwargs):
7
+ vision_tower = getattr(vision_tower_cfg, 'mm_vision_tower', getattr(vision_tower_cfg, 'vision_tower', None))
8
+ is_absolute_path_exists = os.path.exists(vision_tower)
9
+ if is_absolute_path_exists or vision_tower.startswith("openai") or vision_tower.startswith("laion") or "ShareGPT4V" in vision_tower:
10
+ return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
11
+ elif "siglip" in vision_tower:
12
+ return SigLipVisionTower(vision_tower, vision_tower_cfg=vision_tower_cfg, **kwargs)
13
+
14
+ raise ValueError(f'Unknown vision tower: {vision_tower}')
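`build_vision_tower` dispatches purely on the configured tower name: an existing local path, an `openai`/`laion` prefix, or a ShareGPT4V checkpoint goes to `CLIPVisionTower`, and any name containing "siglip" goes to `SigLipVisionTower`. A small illustration with a placeholder config object carrying only the attributes the builder and `CLIPVisionTower` actually read:

    from types import SimpleNamespace
    from llava.model.multimodal_encoder.builder import build_vision_tower

    cfg = SimpleNamespace(
        mm_vision_tower="openai/clip-vit-large-patch14-336",
        mm_vision_select_layer=-2,
        mm_vision_select_feature="patch",
    )
    tower = build_vision_tower(cfg, delay_load=True)  # dispatches to CLIPVisionTower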
llava/model/multimodal_encoder/clip_encoder.py ADDED
@@ -0,0 +1,114 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig
5
+
6
+
7
+ class CLIPVisionTower(nn.Module):
8
+ def __init__(self, vision_tower, args, delay_load=False):
9
+ super().__init__()
10
+
11
+ self.is_loaded = False
12
+
13
+ self.vision_tower_name = vision_tower
14
+ self.select_layer = args.mm_vision_select_layer
15
+ self.select_feature = getattr(args, "mm_vision_select_feature", "patch")
16
+
17
+ if not delay_load:
18
+ self.load_model()
19
+ elif getattr(args, "unfreeze_mm_vision_tower", False):
20
+ # TODO: better detector is needed.
21
+ print(f"The checkpoint seems to contain `vision_tower` weights: `unfreeze_mm_vision_tower`: True.")
22
+ self.load_model()
23
+ else:
24
+ self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name)
25
+
26
+ def load_model(self, device_map=None):
27
+ if self.is_loaded:
28
+ print("{} is already loaded, `load_model` called again, skipping.".format(self.vision_tower_name))
29
+ return
30
+
31
+ # import pdb; pdb.set_trace()
32
+ self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
33
+ self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name, device_map=device_map)
34
+ self.vision_tower.requires_grad_(False)
35
+
36
+ self.is_loaded = True
37
+
38
+ def feature_select(self, image_forward_outs):
39
+ select_feature_type = self.select_feature
40
+
41
+ if self.select_feature in ["slicefour_patch", "slicefour_cls_patch"]:
42
+ select_every_k_layer = len(image_forward_outs.hidden_states) // 4
43
+ image_features = torch.cat([image_forward_outs.hidden_states[i] for i in range(select_every_k_layer + self.select_layer, len(image_forward_outs.hidden_states), select_every_k_layer)], dim=-1)
44
+ select_feature_type = select_feature_type.replace("slicefour_", "")
45
+ elif self.select_feature in ["slice_m25811_f6_patch", "slice_m25811_f6_cls_patch"]:
46
+ select_layers = [-2, -5, -8, -11, 6]
47
+ image_features = torch.cat([image_forward_outs.hidden_states[i] for i in select_layers], dim=-1)
48
+ select_feature_type = select_feature_type.replace("slice_m25811_f6_", "")
49
+ else:
50
+ image_features = image_forward_outs.hidden_states[self.select_layer]
51
+
52
+ if select_feature_type == "patch":
53
+ image_features = image_features[:, 1:]
54
+ elif select_feature_type == "cls_patch":
55
+ image_features = image_features
56
+ else:
57
+ raise ValueError(f"Unexpected select feature: {select_feature_type}")
58
+ return image_features
59
+
60
+ def forward(self, images):
61
+ if type(images) is list:
62
+ image_features = []
63
+ for image in images:
64
+ image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True)
65
+ image_feature = self.feature_select(image_forward_out).to(image.dtype)
66
+ image_features.append(image_feature)
67
+ else:
68
+ image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
69
+ image_features = self.feature_select(image_forward_outs).to(images.dtype)
70
+
71
+ return image_features
72
+
73
+ @property
74
+ def dummy_feature(self):
75
+ return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
76
+
77
+ @property
78
+ def dtype(self):
79
+ return self.vision_tower.dtype
80
+
81
+ @property
82
+ def device(self):
83
+ return self.vision_tower.device
84
+
85
+ @property
86
+ def config(self):
87
+ if self.is_loaded:
88
+ return self.vision_tower.config
89
+ else:
90
+ return self.cfg_only
91
+
92
+ @property
93
+ def hidden_size(self):
94
+ _hidden_size = self.config.hidden_size
95
+ if "slicefour" in self.select_feature:
96
+ _hidden_size *= 4
97
+ if "slice_m25811_f6" in self.select_feature:
98
+ _hidden_size *= 5
99
+ return _hidden_size
100
+
101
+ @property
102
+ def num_patches_per_side(self):
103
+ return self.config.image_size // self.config.patch_size
104
+
105
+ @property
106
+ def num_patches(self):
107
+ _num_patches = (self.config.image_size // self.config.patch_size) ** 2
108
+ if "cls_patch" in self.select_feature:
109
+ _num_patches += 1
110
+ return _num_patches
111
+
112
+ @property
113
+ def image_size(self):
114
+ return self.config.image_size
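The geometry properties at the bottom of `CLIPVisionTower` are simple functions of the loaded config. For the public `openai/clip-vit-large-patch14-336` checkpoint (image_size 336, patch_size 14, vision hidden_size 1024) they work out as in the sketch below, which lazily loads only the config (the args object is a placeholder):

    from types import SimpleNamespace
    from llava.model.multimodal_encoder.clip_encoder import CLIPVisionTower

    args = SimpleNamespace(mm_vision_select_layer=-2, mm_vision_select_feature="patch")
    tower = CLIPVisionTower("openai/clip-vit-large-patch14-336", args=args, delay_load=True)

    print(tower.num_patches_per_side)  # 336 // 14 = 24
    print(tower.num_patches)           # 24 ** 2 = 576 (577 if "cls_patch" were selected)
    print(tower.hidden_size)           # 1024 (x4 for "slicefour_*", x5 for "slice_m25811_f6_*")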
llava/model/multimodal_encoder/siglip_encoder.py ADDED
@@ -0,0 +1,620 @@
1
+ """
2
+ # Adapted from https://huggingface.co/MILVLG/imp-v1-3b/blob/main/vision_encoder.py
3
+ """
4
+
5
+ from typing import Optional, Tuple, Union, Dict
6
+ from dataclasses import dataclass
7
+ from functools import partial, reduce
8
+ from PIL import Image
9
+ import torch
10
+ import torch.utils.checkpoint
11
+ from torch import nn
12
+ import os
13
+ from transformers.image_processing_utils import BatchFeature, get_size_dict
14
+ from transformers.image_transforms import (
15
+ convert_to_rgb,
16
+ normalize,
17
+ rescale,
18
+ resize,
19
+ to_channel_dimension_format,
20
+ )
21
+ from transformers.image_utils import (
22
+ ChannelDimension,
23
+ PILImageResampling,
24
+ to_numpy_array,
25
+ )
26
+ from transformers.activations import ACT2FN
27
+ from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
28
+ from transformers.modeling_utils import PreTrainedModel
29
+ from transformers import PretrainedConfig
30
+ from transformers.utils import ModelOutput
31
+ from llava.utils import rank0_print
32
+
33
+
34
+ class SigLipImageProcessor:
35
+ def __init__(self, image_mean=(0.5, 0.5, 0.5), image_std=(0.5, 0.5, 0.5), size=(384, 384), crop_size: Dict[str, int] = None, resample=PILImageResampling.BICUBIC, rescale_factor=1 / 255, data_format=ChannelDimension.FIRST):
36
+ crop_size = crop_size if crop_size is not None else {"height": 384, "width": 384}
37
+ crop_size = get_size_dict(crop_size, default_to_square=True, param_name="crop_size")
38
+
39
+ self.image_mean = image_mean
40
+ self.image_std = image_std
41
+ self.size = size
42
+ self.resample = resample
43
+ self.rescale_factor = rescale_factor
44
+ self.data_format = data_format
45
+ self.crop_size = crop_size
46
+
47
+ def preprocess(self, images, return_tensors):
48
+ if isinstance(images, Image.Image):
49
+ images = [images]
50
+ else:
51
+ assert isinstance(images, list)
52
+
53
+ transforms = [
54
+ convert_to_rgb,
55
+ to_numpy_array,
56
+ partial(resize, size=self.size, resample=self.resample, data_format=self.data_format),
57
+ partial(rescale, scale=self.rescale_factor, data_format=self.data_format),
58
+ partial(normalize, mean=self.image_mean, std=self.image_std, data_format=self.data_format),
59
+ partial(to_channel_dimension_format, channel_dim=self.data_format, input_channel_dim=self.data_format),
60
+ ]
61
+
62
+ images = reduce(lambda x, f: [*map(f, x)], transforms, images)
63
+ data = {"pixel_values": images}
64
+
65
+ return BatchFeature(data=data, tensor_type=return_tensors)
66
+
67
+
68
+ class SigLipVisionConfig(PretrainedConfig):
69
+ model_type = "siglip_vision_model"
70
+
71
+ def __init__(
72
+ self,
73
+ hidden_size=1152,
74
+ image_mean=(0.5, 0.5, 0.5),
75
+ intermediate_size=4304,
76
+ num_hidden_layers=27,
77
+ num_attention_heads=16,
78
+ num_channels=3,
79
+ image_size=384,
80
+ patch_size=14,
81
+ hidden_act="gelu_pytorch_tanh",
82
+ layer_norm_eps=1e-6,
83
+ attention_dropout=0.0,
84
+ **kwargs,
85
+ ):
86
+ super().__init__(**kwargs)
87
+
88
+ self.hidden_size = hidden_size
89
+ self.intermediate_size = intermediate_size
90
+ self.num_hidden_layers = num_hidden_layers
91
+ self.num_attention_heads = num_attention_heads
92
+ self.num_channels = num_channels
93
+ self.patch_size = patch_size
94
+ self.image_size = image_size
95
+ self.attention_dropout = attention_dropout
96
+ self.layer_norm_eps = layer_norm_eps
97
+ self.hidden_act = hidden_act
98
+ self.image_mean = image_mean
99
+
100
+ @classmethod
101
+ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
102
+ cls._set_token_in_kwargs(kwargs)
103
+
104
+ config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
105
+
106
+ # get the vision config dict if we are loading from SigLipConfig
107
+ if config_dict.get("model_type") == "siglip":
108
+ config_dict = config_dict["vision_config"]
109
+
110
+ if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
111
+ print(f"You are using a model of type {config_dict['model_type']} to instantiate a model of type " f"{cls.model_type}. This is not supported for all configurations of models and can yield errors.")
112
+
113
+ return cls.from_dict(config_dict, **kwargs)
114
+
115
+
116
+ @dataclass
117
+ # Copied from transformers.models.clip.modeling_clip.CLIPVisionModelOutput with CLIP->SigLip
118
+ class SigLipVisionModelOutput(ModelOutput):
119
+ """
120
+ Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states.
121
+
122
+ Args:
123
+ image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
124
+ The image embeddings obtained by applying the projection layer to the pooler_output.
125
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
126
+ Sequence of hidden-states at the output of the last layer of the model.
127
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
128
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
129
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
130
+
131
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
132
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
133
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
134
+ sequence_length)`.
135
+
136
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
137
+ heads.
138
+ """
139
+
140
+ image_embeds: Optional[torch.FloatTensor] = None
141
+ last_hidden_state: torch.FloatTensor = None
142
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
143
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
144
+
145
+
146
+ class SigLipVisionEmbeddings(nn.Module):
147
+ def __init__(self, config: SigLipVisionConfig):
148
+ super().__init__()
149
+ self.config = config
150
+ self.embed_dim = config.hidden_size
151
+ self.image_size = config.image_size
152
+ self.patch_size = config.patch_size
153
+
154
+ self.patch_embedding = nn.Conv2d(
155
+ in_channels=config.num_channels,
156
+ out_channels=self.embed_dim,
157
+ kernel_size=self.patch_size,
158
+ stride=self.patch_size,
159
+ padding="valid",
160
+ )
161
+
162
+ self.num_patches = (self.image_size // self.patch_size) ** 2
163
+ self.num_positions = self.num_patches
164
+ self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
165
+ self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)
166
+
167
+ def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
168
+ patch_embeds = self.patch_embedding(pixel_values) # shape = [*, width, grid, grid]
169
+ embeddings = patch_embeds.flatten(2).transpose(1, 2)
170
+
171
+ embeddings = embeddings + self.position_embedding(self.position_ids)
172
+ return embeddings
173
+
174
+
+ class SigLipAttention(nn.Module):
+     """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+     # Copied from transformers.models.clip.modeling_clip.CLIPAttention.__init__
+     def __init__(self, config):
+         super().__init__()
+         self.config = config
+         self.embed_dim = config.hidden_size
+         self.num_heads = config.num_attention_heads
+         self.head_dim = self.embed_dim // self.num_heads
+         if self.head_dim * self.num_heads != self.embed_dim:
+             raise ValueError(f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {self.num_heads}).")
+         self.scale = self.head_dim**-0.5
+         self.dropout = config.attention_dropout
+
+         self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
+         self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
+         self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
+         self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         attention_mask: Optional[torch.Tensor] = None,
+         output_attentions: Optional[bool] = False,
+     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+         """Input shape: Batch x Time x Channel"""
+
+         batch_size, q_len, _ = hidden_states.size()
+
+         query_states = self.q_proj(hidden_states)
+         key_states = self.k_proj(hidden_states)
+         value_states = self.v_proj(hidden_states)
+
+         query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+         key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+         value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+
+         k_v_seq_len = key_states.shape[-2]
+         attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale
+
+         if attn_weights.size() != (batch_size, self.num_heads, q_len, k_v_seq_len):
+             raise ValueError(f"Attention weights should be of size {(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is {attn_weights.size()}")
+
+         if attention_mask is not None:
+             if attention_mask.size() != (batch_size, 1, q_len, k_v_seq_len):
+                 raise ValueError(f"Attention mask should be of size {(batch_size, 1, q_len, k_v_seq_len)}, but is {attention_mask.size()}")
+             attn_weights = attn_weights + attention_mask
+
+         # upcast attention to fp32
+         attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+         attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
+         attn_output = torch.matmul(attn_weights, value_states)
+
+         if attn_output.size() != (batch_size, self.num_heads, q_len, self.head_dim):
+             raise ValueError(f"`attn_output` should be of size {(batch_size, self.num_heads, q_len, self.head_dim)}, but is {attn_output.size()}")
+
+         attn_output = attn_output.transpose(1, 2).contiguous()
+         attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim)
+
+         attn_output = self.out_proj(attn_output)
+
+         return attn_output, attn_weights
+
+
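The forward pass above is plain softmax(Q K^T * scale) V attention with an fp32 softmax. As a minimal sketch (assuming PyTorch >= 2.0 and the so400m-style head layout of 16 heads of 72 dims, which are illustrative numbers), it matches the fused `scaled_dot_product_attention` kernel up to numerical tolerance:

```python
import torch
import torch.nn.functional as F

batch, heads, seq, head_dim = 2, 16, 729, 72  # illustrative shapes
q, k, v = (torch.randn(batch, heads, seq, head_dim) for _ in range(3))

manual = F.softmax(q @ k.transpose(2, 3) * head_dim**-0.5, dim=-1) @ v
fused = F.scaled_dot_product_attention(q, k, v)  # no mask, no dropout
print(torch.allclose(manual, fused, atol=1e-4))  # True
```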
+ # Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->SigLip
+ class SigLipMLP(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.config = config
+         self.activation_fn = ACT2FN[config.hidden_act]
+         self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
+         self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
+
+     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+         hidden_states = self.fc1(hidden_states)
+         hidden_states = self.activation_fn(hidden_states)
+         hidden_states = self.fc2(hidden_states)
+         return hidden_states
+
+
+ # Copied from transformers.models.clip.modeling_clip.CLIPEncoderLayer with CLIP->SigLip
+ class SigLipEncoderLayer(nn.Module):
+     def __init__(self, config: SigLipVisionConfig):
+         super().__init__()
+         self.embed_dim = config.hidden_size
+         self.self_attn = SigLipAttention(config)
+         self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
+         self.mlp = SigLipMLP(config)
+         self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
+
+     # Ignore copy
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         attention_mask: torch.Tensor,
+         output_attentions: Optional[bool] = False,
+     ) -> Tuple[torch.FloatTensor]:
+         """
+         Args:
+             hidden_states (`torch.FloatTensor`):
+                 Input to the layer of shape `(batch, seq_len, embed_dim)`.
+             attention_mask (`torch.FloatTensor`):
+                 Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values.
+             output_attentions (`bool`, *optional*, defaults to `False`):
+                 Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                 returned tensors for more detail.
+         """
+         residual = hidden_states
+
+         hidden_states = self.layer_norm1(hidden_states)
+         hidden_states, attn_weights = self.self_attn(
+             hidden_states=hidden_states,
+             attention_mask=attention_mask,
+             output_attentions=output_attentions,
+         )
+         hidden_states = residual + hidden_states
+
+         residual = hidden_states
+         hidden_states = self.layer_norm2(hidden_states)
+         hidden_states = self.mlp(hidden_states)
+         hidden_states = residual + hidden_states
+
+         outputs = (hidden_states,)
+
+         if output_attentions:
+             outputs += (attn_weights,)
+
+         return outputs
+
+
+ class SigLipPreTrainedModel(PreTrainedModel):
+     """
+     An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+     models.
+     """
+
+     config_class = SigLipVisionConfig
+     base_model_prefix = "siglip"
+     supports_gradient_checkpointing = True
+
+     def _init_weights(self, module):
+         """Initialize the weights"""
+         pass
+
+
+ # Copied from transformers.models.clip.modeling_clip.CLIPEncoder with CLIP->SigLip
+ class SigLipEncoder(nn.Module):
+     """
+     Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
+     [`SigLipEncoderLayer`].
+
+     Args:
+         config: SigLipVisionConfig
+     """
+
+     def __init__(self, config: SigLipVisionConfig):
+         super().__init__()
+         self.config = config
+         self.layers = nn.ModuleList([SigLipEncoderLayer(config) for _ in range(config.num_hidden_layers)])
+         self.gradient_checkpointing = False
+
+     # Ignore copy
+     def forward(
+         self,
+         inputs_embeds,
+         attention_mask: Optional[torch.Tensor] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple, BaseModelOutput]:
+         r"""
+         Args:
+             inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+                 Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
+                 This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+                 than the model's internal embedding lookup matrix.
+             attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+                 Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+                 - 1 for tokens that are **not masked**,
+                 - 0 for tokens that are **masked**.
+
+                 [What are attention masks?](../glossary#attention-mask)
+             output_attentions (`bool`, *optional*):
+                 Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                 returned tensors for more detail.
+             output_hidden_states (`bool`, *optional*):
+                 Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+                 for more detail.
+             return_dict (`bool`, *optional*):
+                 Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+         """
+         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+         output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         encoder_states = () if output_hidden_states else None
+         all_attentions = () if output_attentions else None
+
+         hidden_states = inputs_embeds
+         for encoder_layer in self.layers:
+             if output_hidden_states:
+                 encoder_states = encoder_states + (hidden_states,)
+             if self.gradient_checkpointing and self.training:
+                 layer_outputs = self._gradient_checkpointing_func(
+                     encoder_layer.__call__,
+                     hidden_states,
+                     attention_mask,
+                     output_attentions,
+                 )
+             else:
+                 layer_outputs = encoder_layer(
+                     hidden_states,
+                     attention_mask,
+                     output_attentions=output_attentions,
+                 )
+
+             hidden_states = layer_outputs[0]
+
+             if output_attentions:
+                 all_attentions = all_attentions + (layer_outputs[1],)
+
+         if output_hidden_states:
+             encoder_states = encoder_states + (hidden_states,)
+
+         if not return_dict:
+             return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
+         return BaseModelOutput(last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions)
+
+
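With `output_hidden_states=True`, `hidden_states` collects the embedding output plus one entry per layer, which is why `SigLipVisionTower` below can read the last layer's tokens via `hidden_states[-1]`. A minimal sketch, assuming `SigLipVisionConfig` accepts the usual HF-style keyword overrides; the tiny sizes are hypothetical:

```python
import torch

# Hypothetical tiny config; the kwargs mirror the standard HF SiglipVisionConfig fields.
cfg = SigLipVisionConfig(hidden_size=64, intermediate_size=128, num_hidden_layers=2, num_attention_heads=4)
encoder = SigLipEncoder(cfg)
out = encoder(torch.randn(1, 9, 64), output_hidden_states=True, return_dict=True)
print(len(out.hidden_states))       # 3 == num_hidden_layers + 1
print(out.hidden_states[-1].shape)  # torch.Size([1, 9, 64])
```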
+ class SigLipVisionTransformer(nn.Module):
+     def __init__(self, config: SigLipVisionConfig):
+         super().__init__()
+         self.config = config
+         embed_dim = config.hidden_size
+
+         self.embeddings = SigLipVisionEmbeddings(config)
+         self.encoder = SigLipEncoder(config)
+         self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
+         self.head = SigLipMultiheadAttentionPoolingHead(config)
+
+     def forward(
+         self,
+         pixel_values,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple, BaseModelOutputWithPooling]:
+         r"""
+         Returns:
+
+         """
+         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+         output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         hidden_states = self.embeddings(pixel_values)
+
+         encoder_outputs = self.encoder(
+             inputs_embeds=hidden_states,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+
+         last_hidden_state = encoder_outputs[0]
+         last_hidden_state = self.post_layernorm(last_hidden_state)
+
+         pooled_output = self.head(last_hidden_state)
+
+         if not return_dict:
+             return (last_hidden_state, pooled_output) + encoder_outputs[1:]
+
+         return BaseModelOutputWithPooling(
+             last_hidden_state=last_hidden_state,
+             pooler_output=pooled_output,
+             hidden_states=encoder_outputs.hidden_states,
+             attentions=encoder_outputs.attentions,
+         )
+
+
+ class SigLipMultiheadAttentionPoolingHead(nn.Module):
+     """Multihead Attention Pooling."""
+
+     def __init__(self, config: SigLipVisionConfig):
+         super().__init__()
+
+         self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size))
+         self.attention = torch.nn.MultiheadAttention(config.hidden_size, config.num_attention_heads, batch_first=True)
+         self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+         self.mlp = SigLipMLP(config)
+
+     def forward(self, hidden_state):
+         batch_size = hidden_state.shape[0]
+         probe = self.probe.repeat(batch_size, 1, 1)
+
+         hidden_state = self.attention(probe, hidden_state, hidden_state)[0]
+
+         residual = hidden_state
+         hidden_state = self.layernorm(hidden_state)
+         hidden_state = residual + self.mlp(hidden_state)
+
+         return hidden_state[:, 0]
+
+
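The pooling head condenses the whole patch sequence into one vector by letting a single learned probe query attend over all tokens. A standalone sketch of that shape contract, with illustrative so400m-like sizes:

```python
import torch
import torch.nn as nn

hidden, heads, batch, num_patches = 1152, 16, 2, 729  # illustrative sizes
probe = torch.randn(1, 1, hidden)                      # one learned query
attn = nn.MultiheadAttention(hidden, heads, batch_first=True)

tokens = torch.randn(batch, num_patches, hidden)
pooled, _ = attn(probe.repeat(batch, 1, 1), tokens, tokens)  # query=probe, key=value=tokens
print(pooled[:, 0].shape)  # torch.Size([2, 1152]) -- one pooled vector per image
```

Note that `SigLipVisionTower.load_model` below swaps this head for `nn.Identity()`, so the LLaVA pipeline keeps the full 729-token sequence rather than the pooled vector.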
+ class SigLipVisionModel(SigLipPreTrainedModel):
+     config_class = SigLipVisionConfig
+     main_input_name = "pixel_values"
+     _no_split_modules = ["SigLipEncoderLayer"]
+
+     def __init__(self, config: SigLipVisionConfig):
+         super().__init__(config)
+
+         self.vision_model = SigLipVisionTransformer(config)
+
+         # Initialize weights and apply final processing
+         self.post_init()
+
+     def get_input_embeddings(self) -> nn.Module:
+         return self.vision_model.embeddings.patch_embedding
+
+     def forward(
+         self,
+         pixel_values,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple, BaseModelOutputWithPooling]:
+         r"""
+         Returns:
+
+         Examples:
+
+         ```python
+         >>> from PIL import Image
+         >>> import requests
+         >>> from transformers import AutoProcessor, SigLipVisionModel
+
+         >>> model = SigLipVisionModel.from_pretrained("google/siglip-base-patch16-224")
+         >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")
+
+         >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+         >>> image = Image.open(requests.get(url, stream=True).raw)
+
+         >>> inputs = processor(images=image, return_tensors="pt")
+
+         >>> outputs = model(**inputs)
+         >>> last_hidden_state = outputs.last_hidden_state
+         >>> pooled_output = outputs.pooler_output  # pooled features
+         ```"""
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         return self.vision_model(
+             pixel_values=pixel_values.to(self.device),
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+
+
+ class SigLipVisionTower(nn.Module):
+     def __init__(self, vision_tower, vision_tower_cfg, delay_load=False):
+         super().__init__()
+
+         self.is_loaded = False
+
+         self.config = SigLipVisionConfig()
+
+         self.vision_tower_name = vision_tower
+
+         self.image_processor = SigLipImageProcessor()
+
+         if not delay_load:
+             rank0_print(f"Loading vision tower: {vision_tower}")
+             self.load_model()
+         elif getattr(vision_tower_cfg, "unfreeze_mm_vision_tower", False):
+             # TODO: better detector is needed.
+             rank0_print(f"The checkpoint seems to contain `vision_tower` weights: `unfreeze_mm_vision_tower`: True.")
+             self.load_model()
+         elif hasattr(vision_tower_cfg, "mm_tunable_parts") and "mm_vision_tower" in vision_tower_cfg.mm_tunable_parts:
+             rank0_print(f"The checkpoint seems to contain `vision_tower` weights: `mm_tunable_parts` contains `mm_vision_tower`.")
+             self.load_model()
+         else:
+             self.cfg_only = self.config
+
+     def load_model(self, device_map=None):
+         if self.is_loaded:
+             return
+
+         self.vision_tower = SigLipVisionModel.from_pretrained(self.vision_tower_name, device_map=device_map)
+
+         # Drop the final encoder layer and replace the pooling head: LLaVA consumes
+         # the full patch-token sequence, not the pooled embedding.
+         del self.vision_tower.vision_model.encoder.layers[-1:]
+         self.vision_tower.vision_model.head = nn.Identity()
+         self.vision_tower.requires_grad_(False)
+         self.vision_tower.eval()
+
+         self.is_loaded = True
+
+     @torch.no_grad()
+     def forward(self, images):
+         if type(images) is list:
+             image_features = []
+             for image in images:
+                 image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True)
+                 image_feature = image_forward_out.hidden_states[-1].to(image.dtype)
+                 assert image_feature.shape[-2] == 729
+                 image_features.append(image_feature)
+         else:
+             images = images.to(device=self.device, dtype=self.dtype)
+             image_forward_outs = self.vision_tower(images, output_hidden_states=True)
+             image_features = image_forward_outs.hidden_states[-1].to(images.dtype)
+             assert image_features.shape[-2] == 729
+
+         return image_features
+
+     @property
+     def dummy_feature(self):
+         return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
+
+     @property
+     def dtype(self):
+         for p in self.vision_tower.parameters():
+             return p.dtype
+
+     @property
+     def device(self):
+         for p in self.vision_tower.parameters():
+             return p.device
+
+     @property
+     def hidden_size(self):
+         return self.config.hidden_size
+
+     @property
+     def num_patches(self):
+         return (self.config.image_size // self.config.patch_size) ** 2
+
+     @property
+     def num_patches_per_side(self):
+         return self.config.image_size // self.config.patch_size
+         # return self.model_config["vision_cfg"]["image_size"] // self.model_config["vision_cfg"]["patch_size"]
+
+     @property
+     def image_size(self):
+         return self.config.image_size
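A hedged end-to-end sketch of exercising the tower on its own (the checkpoint name and the minimal `SimpleNamespace` config are assumptions for illustration, not values taken from this repo's training configs):

```python
import torch
from types import SimpleNamespace

cfg = SimpleNamespace(unfreeze_mm_vision_tower=False)  # hypothetical minimal cfg object
tower = SigLipVisionTower("google/siglip-so400m-patch14-384", vision_tower_cfg=cfg)

pixel_values = torch.randn(2, 3, 384, 384).to(device=tower.device, dtype=tower.dtype)
features = tower(pixel_values)
print(features.shape)  # expected: torch.Size([2, 729, 1152])
```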
llava/model/multimodal_projector/__pycache__/builder.cpython-310.pyc ADDED
Binary file (2.39 kB)
 
llava/model/multimodal_projector/__pycache__/pooler_projector.cpython-310.pyc ADDED
Binary file (1.46 kB)
 
llava/model/multimodal_projector/builder.py ADDED
@@ -0,0 +1,65 @@
+ import torch
+ import torch.nn as nn
+ import re
+
+ from .pooler_projector import PoolerProjector
+
+
+ class IdentityMap(nn.Module):
+     def __init__(self):
+         super().__init__()
+
+     def forward(self, x, *args, **kwargs):
+         return x
+
+     @property
+     def config(self):
+         return {"mm_projector_type": "identity"}
+
+
+ class SimpleResBlock(nn.Module):
+     def __init__(self, channels):
+         super().__init__()
+         self.pre_norm = nn.LayerNorm(channels)
+
+         self.proj = nn.Sequential(nn.Linear(channels, channels), nn.GELU(), nn.Linear(channels, channels))
+
+     def forward(self, x):
+         x = self.pre_norm(x)
+         return x + self.proj(x)
+
+
+ def build_vision_projector(config, delay_load=False, **kwargs):
+     projector_type = getattr(config, "mm_projector_type", "linear")
+
+     if projector_type == "linear":
+         return nn.Linear(config.mm_hidden_size, config.hidden_size)
+
+     if projector_type == "pooler":
+         return PoolerProjector(config, kwargs["vision_cfg"])
+
+     mlp_gelu_match = re.match(r"^mlp(\d+)x_gelu$", projector_type)
+     if mlp_gelu_match:
+         mlp_depth = int(mlp_gelu_match.group(1))
+         modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
+         for _ in range(1, mlp_depth):
+             modules.append(nn.GELU())
+             modules.append(nn.Linear(config.hidden_size, config.hidden_size))
+         return nn.Sequential(*modules)
+
+     mlp_gelu_resnet_match = re.match(r"^mlp(\d+)x_res(\d+)x_gelu$", projector_type)
+     if mlp_gelu_resnet_match:
+         mlp_depth = int(mlp_gelu_resnet_match.group(1))
+         res_depth = int(mlp_gelu_resnet_match.group(2))
+         modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
+         for _ in range(1, mlp_depth):
+             modules.append(nn.GELU())
+             modules.append(nn.Linear(config.hidden_size, config.hidden_size))
+         for _ in range(res_depth):
+             modules.append(SimpleResBlock(config.hidden_size))
+         return nn.Sequential(*modules)
+
+     if projector_type == "identity":
+         return IdentityMap()
+
+     raise ValueError(f"Unknown projector type: {projector_type}")
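A hedged usage sketch of the naming convention parsed above: `"mlp2x_gelu"` yields Linear -> GELU -> Linear, mapping vision features into the LLM's hidden size. The hidden sizes below are illustrative, not taken from a shipped config:

```python
import torch
from types import SimpleNamespace

config = SimpleNamespace(mm_projector_type="mlp2x_gelu", mm_hidden_size=1152, hidden_size=3584)
projector = build_vision_projector(config)

image_features = torch.randn(2, 729, 1152)  # [batch, tokens, vision hidden]
print(projector(image_features).shape)      # torch.Size([2, 729, 3584]) -- LLM hidden size
```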
llava/model/multimodal_projector/pooler_projector.py ADDED
@@ -0,0 +1,33 @@
+ import torch
+ import torch.nn as nn
+
+ import math
+
+ from transformers.models.clip.modeling_clip import CLIPVisionModel
+
+
+ class PoolerProjector(nn.Module):
+     def __init__(self, config, vision_cfg):
+         super().__init__()
+         self._config = config
+         self.hw = vision_cfg.image_size // vision_cfg.patch_size
+
+         self.conv_pool = nn.Conv2d(config.mm_hidden_size, config.hidden_size, kernel_size=2, stride=2)
+
+         self.proj = nn.Sequential(
+             nn.GELU(),
+             nn.Linear(config.hidden_size, config.hidden_size),
+         )
+
+     def forward(self, x, *args, **kwargs):
+         height = width = self.hw
+         assert height * width == x.shape[1]
+         x = x.view(x.shape[0], height, width, -1).permute(0, 3, 1, 2)
+         x = self.conv_pool(x)
+         x = x.flatten(2).transpose(1, 2)
+         x = self.proj(x)
+         return x
+
+     @property
+     def config(self):
+         return {"mm_projector_type": "pooler"}
llava/model/multimodal_resampler/__pycache__/builder.cpython-310.pyc ADDED
Binary file (1.44 kB)
 
llava/model/multimodal_resampler/__pycache__/masked_drop.cpython-310.pyc ADDED
Binary file (2.46 kB)
 
llava/model/multimodal_resampler/__pycache__/perceiver.cpython-310.pyc ADDED
Binary file (4.85 kB)
 
llava/model/multimodal_resampler/__pycache__/qformer.cpython-310.pyc ADDED
Binary file (32.7 kB)
 
llava/model/multimodal_resampler/__pycache__/spatial_pool.cpython-310.pyc ADDED
Binary file (1.89 kB)
 
llava/model/multimodal_resampler/builder.py ADDED
@@ -0,0 +1,34 @@
+ import torch
+
+ from .masked_drop import MaskedDrop
+ from .spatial_pool import SpatialPool
+ from .perceiver import PerceiverResampler
+ from .qformer import Qformer
+
+
+ class IdentityMap(torch.nn.Module):
+     def __init__(self):
+         super().__init__()
+
+     def forward(self, x, *args, **kwargs):
+         return x
+
+     @property
+     def config(self):
+         return {"mm_resampler_type": None}
+
+
+ def build_vision_resampler(model_args, delay_load=False, **kwargs):
+     resampler_type = getattr(model_args, "mm_resampler_type", None)
+     if resampler_type == "masked_drop":
+         return MaskedDrop(model_args)
+     elif resampler_type == "spatial_pool":
+         return SpatialPool(model_args, **kwargs)
+     elif resampler_type == "perceiver":
+         return PerceiverResampler(model_args, **kwargs)
+     elif resampler_type == "qformer":
+         return Qformer(model_args, **kwargs)
+     elif resampler_type is None:
+         return IdentityMap()
+
+     raise ValueError(f"Unknown resampler type: {resampler_type}")
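A hedged sketch of the default path: when `mm_resampler_type` is not set on the args object, the builder falls back to `IdentityMap` and image features pass through untouched:

```python
import torch
from types import SimpleNamespace

model_args = SimpleNamespace()  # hypothetical args object without mm_resampler_type
resampler = build_vision_resampler(model_args)

features = torch.randn(2, 729, 1152)
assert resampler(features) is features  # identity pass-through
print(resampler.config)                 # {'mm_resampler_type': None}
```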
llava/model/multimodal_resampler/masked_drop.py ADDED
@@ -0,0 +1,80 @@
+ import torch
+ import torch.nn as nn
+
+ import random
+
+
+ class MaskedDrop(nn.Module):
+     def __init__(self, model_args):
+         super().__init__()
+
+         self.mode = model_args.mm_mask_drop_mode
+         self.skip_percentage = model_args.mm_mask_drop_skip_percentage
+         self.ratio = model_args.mm_mask_drop_ratio
+         self.ratio_upper = model_args.mm_mask_drop_ratio_upper
+         self.ratio_lower = model_args.mm_mask_drop_ratio_lower
+
+     def forward(self, image_features, *args, **kwargs):
+
+         if not self.training:
+             return image_features
+
+         if self.skip_percentage > random.random():
+             return image_features
+
+         masked_features = []
+
+         for image_feature in image_features:
+             num_tokens = image_feature.shape[0]
+             if self.mode == "fixed":
+                 num_keep = int(num_tokens * self.ratio)
+                 masked_features.append(self.random_masking(image_feature.unsqueeze(0), num_keep)[0][0])
+             elif self.mode == "range":
+                 num_keep = int(num_tokens * random.uniform(self.ratio_lower, self.ratio_upper))
+                 masked_features.append(self.random_masking(image_feature.unsqueeze(0), num_keep)[0])
+             elif self.mode == "cls_only":
+                 masked_features.append(image_feature[0:1])
+             else:
+                 raise ValueError(f"Unexpected masked drop mode: {self.mode}")
+
+         if self.mode not in ["range"] and (type(image_features) is not list or self.mode in ["cls_only"]):
+             masked_features = torch.stack(masked_features, dim=0)
+
+         return masked_features
+
+     @property
+     def config(self):
+         return {
+             "mm_resampler_type": "masked_drop",
+             "mm_mask_drop_mode": self.mode,
+             "mm_mask_drop_skip_percentage": self.skip_percentage,
+             "mm_mask_drop_ratio": self.ratio,
+             "mm_mask_drop_ratio_upper": self.ratio_upper,
+             "mm_mask_drop_ratio_lower": self.ratio_lower,
+         }
+
+     def random_masking(self, x, len_keep):
+         """
+         Perform per-sample random masking by per-sample shuffling.
+         Per-sample shuffling is done by argsort random noise.
+         x: [N, L, D], sequence
+         """
+         N, L, D = x.shape  # batch, length, dim
+
+         noise = torch.rand(N, L, device=x.device)  # noise in [0, 1]
+
+         # sort noise for each sample
+         ids_shuffle = torch.argsort(noise, dim=1)  # ascend: small is keep, large is remove
+         ids_restore = torch.argsort(ids_shuffle, dim=1)
+
+         # keep the first subset
+         ids_keep = ids_shuffle[:, :len_keep]
+         x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
+
+         # generate the binary mask: 0 is keep, 1 is remove
+         mask = torch.ones([N, L], device=x.device)
+         mask[:, :len_keep] = 0
+         # unshuffle to get the binary mask
+         mask = torch.gather(mask, dim=1, index=ids_restore)
+
+         return x_masked, mask, ids_restore
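A hedged sketch of the core `random_masking` gather, reproduced standalone so it does not need the `mm_mask_drop_*` training arguments; it keeps roughly 30% of 729 tokens (illustrative sizes):

```python
import torch

N, L, D, len_keep = 2, 729, 1152, 219  # illustrative sizes; keep ~30% of tokens
x = torch.randn(N, L, D)

noise = torch.rand(N, L)                              # one random score per token
ids_keep = torch.argsort(noise, dim=1)[:, :len_keep]  # lowest-noise tokens survive
x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
print(x_masked.shape)  # torch.Size([2, 219, 1152])
```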