Wi-zz committed on
Commit
2908489
•
1 Parent(s): 18053d6

Create app-multi-alpha.py

Files changed (1)
  1. app-multi-alpha.py +210 -0
app-multi-alpha.py ADDED
@@ -0,0 +1,210 @@
+ import torch
+ import torch.amp.autocast_mode
+ import torch.distributed as dist
+ import torch.multiprocessing as mp
+ from torch.nn.parallel import DistributedDataParallel as DDP
+ import os
+ import sys
+ import logging
+ import warnings
+ import argparse
+ from PIL import Image
+ from pathlib import Path
+ from tqdm import tqdm
+ from torch import nn
+ from transformers import AutoModel, AutoProcessor, AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast, AutoModelForCausalLM
+ from typing import List, Union
+
+ # Constants
+ CLIP_PATH = "google/siglip-so400m-patch14-384"
+ VLM_PROMPT = "A descriptive caption for this image:\n"
+ MODEL_PATH = "unsloth/Meta-Llama-3.1-8B-bnb-4bit"
+ CHECKPOINT_PATH = Path("wpkklhc6")
+ IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp', '.webp')
+
+ warnings.filterwarnings("ignore", category=UserWarning)
+ logging.getLogger("transformers").setLevel(logging.ERROR)
+
+ def setup(rank, world_size):
+     # One process per GPU: initialize the NCCL group used for multi-GPU inference
+     os.environ['MASTER_ADDR'] = 'localhost'
+     os.environ['MASTER_PORT'] = '12355'
+     dist.init_process_group("nccl", rank=rank, world_size=world_size)
+
+ def cleanup():
+     dist.destroy_process_group()
+
+ # Two-layer MLP that projects SigLIP vision features into the LLM's embedding space
+ class ImageAdapter(nn.Module):
+     def __init__(self, input_features: int, output_features: int):
+         super().__init__()
+         self.linear1 = nn.Linear(input_features, output_features)
+         self.activation = nn.GELU()
+         self.linear2 = nn.Linear(output_features, output_features)
+
+     def forward(self, vision_outputs: torch.Tensor):
+         return self.linear2(self.activation(self.linear1(vision_outputs)))
+
+ def load_models(rank):
+     print(f"Loading CLIP 📎 on GPU {rank}")
+     clip_processor = AutoProcessor.from_pretrained(CLIP_PATH)
+     clip_model = AutoModel.from_pretrained(CLIP_PATH).vision_model.eval().requires_grad_(False).to(rank)
+
+     print(f"Loading tokenizer 🪙 on GPU {rank}")
+     tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, use_fast=False)
+     assert isinstance(tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)), f"Tokenizer is of type {type(tokenizer)}"
+
+     print(f"Loading LLM 🤖 on GPU {rank}")
+     text_model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, device_map={"": rank}, torch_dtype=torch.bfloat16).eval()
+
+     print(f"Loading image adapter 🖼️ on GPU {rank}")
+     image_adapter = ImageAdapter(clip_model.config.hidden_size, text_model.config.hidden_size)
+     image_adapter.load_state_dict(torch.load(CHECKPOINT_PATH / "image_adapter.pt", map_location=f"cuda:{rank}", weights_only=True))
+     image_adapter.eval().to(rank)
+
+     return clip_processor, clip_model, tokenizer, text_model, image_adapter
+
+ @torch.no_grad()
+ def stream_chat(input_images: List[Image.Image], batch_size: int, pbar: tqdm, models: tuple, rank: int) -> List[str]:
+     clip_processor, clip_model, tokenizer, text_model, image_adapter = models
+     torch.cuda.empty_cache()
+     all_captions = []
+
+     for i in range(0, len(input_images), batch_size):
+         batch = input_images[i:i+batch_size]
+
+         try:
+             images = clip_processor(images=batch, return_tensors='pt', padding=True).pixel_values.to(rank)
+         except ValueError as e:
+             print(f"Error processing image batch: {e}")
+             print("Skipping this batch and continuing...")
+             continue
+
+         # Encode the batch with SigLIP and project it into the LLM's embedding space
+         with torch.amp.autocast_mode.autocast('cuda', enabled=True):
+             vision_outputs = clip_model(pixel_values=images, output_hidden_states=True)
+             image_features = vision_outputs.hidden_states[-2]
+             embedded_images = image_adapter(image_features).to(dtype=torch.bfloat16)
+
+         # Build the input sequence: <BOS>, image embeddings, then the text prompt.
+         # add_special_tokens=False so the BOS token is not inserted twice; it is prepended manually below.
+         prompt = tokenizer.encode(VLM_PROMPT, return_tensors='pt', add_special_tokens=False)
+         prompt_embeds = text_model.model.embed_tokens(prompt.to(rank)).to(dtype=torch.bfloat16)
+         embedded_bos = text_model.model.embed_tokens(torch.tensor([[tokenizer.bos_token_id]], device=rank, dtype=torch.int64)).to(dtype=torch.bfloat16)
+
+         inputs_embeds = torch.cat([
+             embedded_bos.expand(embedded_images.shape[0], -1, -1),
+             embedded_images,
+             prompt_embeds.expand(embedded_images.shape[0], -1, -1),
+         ], dim=1).to(dtype=torch.bfloat16)
+
+         # Matching token ids, with zeros standing in for the image positions
+         input_ids = torch.cat([
+             torch.tensor([[tokenizer.bos_token_id]], dtype=torch.long).expand(embedded_images.shape[0], -1),
+             torch.zeros((embedded_images.shape[0], embedded_images.shape[1]), dtype=torch.long),
+             prompt.expand(embedded_images.shape[0], -1),
+         ], dim=1).to(rank)
+
+         attention_mask = torch.ones_like(input_ids)
+
+         generate_ids = text_model.generate(
+             input_ids=input_ids,
+             inputs_embeds=inputs_embeds,
+             attention_mask=attention_mask,
+             max_new_tokens=300,
+             do_sample=True,
+             top_k=10,
+             temperature=0.5,
+         )
+
+         # Keep only the newly generated tokens, then strip any leftover special tokens
+         generate_ids = generate_ids[:, input_ids.shape[1]:]
+
+         for ids in generate_ids:
+             caption = tokenizer.decode(ids[:-1] if ids[-1] == tokenizer.eos_token_id else ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
+             caption = caption.replace('<|end_of_text|>', '').replace('<|finetune_right_pad_id|>', '').strip()
+             all_captions.append(caption)
+
+         if pbar and rank == 0:
+             pbar.update(len(batch))
+
+     return all_captions
+
+ def process_directory(rank, world_size, input_dir: Path, output_dir: Path, batch_size: int, models: tuple):
+     output_dir.mkdir(parents=True, exist_ok=True)
+     image_files = [f for f in input_dir.iterdir() if f.suffix.lower() in IMAGE_EXTENSIONS]
+     images_to_process = [f for f in image_files if not (output_dir / f"{f.stem}.txt").exists()]
+
+     if not images_to_process:
+         if rank == 0:
+             print("No new images to process.")
+         return
+
+     # Distribute images across GPUs; the last rank absorbs the remainder
+     images_per_gpu = len(images_to_process) // world_size
+     start_idx = rank * images_per_gpu
+     end_idx = start_idx + images_per_gpu if rank < world_size - 1 else len(images_to_process)
+     gpu_images = images_to_process[start_idx:end_idx]
+
+     # Only rank 0 drives the progress bar, so it tracks rank 0's share of the work
+     if rank == 0:
+         pbar = tqdm(total=len(images_to_process), desc="Processing images", unit="image")
+     else:
+         pbar = None
+
+     for i in range(0, len(gpu_images), batch_size):
+         batch_files = gpu_images[i:i+batch_size]
+         batch_images = [Image.open(f).convert('RGB') for f in batch_files]
+
+         captions = stream_chat(batch_images, batch_size, pbar, models, rank)
+
+         for file, caption in zip(batch_files, captions):
+             with open(output_dir / f"{file.stem}.txt", 'w', encoding='utf-8') as f:
+                 f.write(caption)
+
+         for img in batch_images:
+             img.close()
+
+     if rank == 0:
+         pbar.close()
+
+ def parse_arguments():
+     parser = argparse.ArgumentParser(description="Process images and generate captions.")
+     parser.add_argument("input", nargs='+', help="Input image file or directory (or multiple directories)")
+     parser.add_argument("--output", help="Output directory (optional)")
+     parser.add_argument("--bs", type=int, default=4, help="Batch size (default: 4)")
+     return parser.parse_args()
+
+ def run(rank, world_size, args):
+     setup(rank, world_size)
+
+     input_paths = [Path(input_path) for input_path in args.input]
+     batch_size = args.bs
+     models = load_models(rank)
+
+     for input_path in input_paths:
+         if input_path.is_file() and input_path.suffix.lower() in IMAGE_EXTENSIONS:
+             # A single image is handled by rank 0 alone
+             if rank == 0:
+                 output_path = input_path.with_suffix('.txt')
+                 print(f"Processing single image 🎞️: {input_path.name}")
+                 with tqdm(total=1, desc="Processing image", unit="image") as pbar:
+                     captions = stream_chat([Image.open(input_path).convert('RGB')], 1, pbar, models, rank)
+                 with open(output_path, 'w', encoding='utf-8') as f:
+                     f.write(captions[0])
+                 print(f"Output saved to {output_path}")
+         elif input_path.is_dir():
+             output_path = Path(args.output) if args.output else input_path
+             if rank == 0:
+                 print(f"Processing directory 📁: {input_path}")
+                 print(f"Output directory 📦: {output_path}")
+                 print(f"Batch size 🗄️: {batch_size}")
+             process_directory(rank, world_size, input_path, output_path, batch_size, models)
+         else:
+             if rank == 0:
+                 print(f"Invalid input: {input_path}")
+                 print("Skipping...")
+
+     cleanup()
+
+ def main():
+     args = parse_arguments()
+     world_size = torch.cuda.device_count()
+     if world_size > 1:
+         # Spawn one worker per GPU; mp.spawn passes each process its rank as the first argument
+         mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
+     else:
+         run(0, 1, args)
+
+ if __name__ == "__main__":
+     main()
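
Usage sketch, based on the argparse definition above (paths are illustrative; since CHECKPOINT_PATH is the relative path wpkklhc6, run the script from the directory containing that checkpoint folder):

    # Caption every image in one or more directories, 8 images per batch
    python app-multi-alpha.py /data/photos /data/scans --bs 8 --output /data/captions

    # Caption a single image; the .txt file is written next to the image
    python app-multi-alpha.py /data/photos/cat.jpg

When torch.cuda.device_count() reports more than one GPU, main() spawns one worker per device via mp.spawn, and process_directory splits the pending images across ranks.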