#
# For a single image
# python app.py image.jpg
#
# For a single directory
# python app.py /path/to/directory
#
# For multiple directories
# python app.py /path/to/directory1 /path/to/directory2 /path/to/directory3
#
# With output directory specified
# python app.py /path/to/directory1 /path/to/directory2 --output /path/to/output
#
# With batch size specified
# python app.py /path/to/directory1 /path/to/directory2 --bs 8

import torch
import torch.amp.autocast_mode
import sys
import logging
import warnings
import argparse
from PIL import Image
from pathlib import Path
from tqdm import tqdm
from torch import nn
from transformers import AutoModel, AutoProcessor, AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast, AutoModelForCausalLM
from typing import List

CLIP_PATH = "google/siglip-so400m-patch14-384"
VLM_PROMPT = "A descriptive caption for this image:\n"
MODEL_PATH = "unsloth/Meta-Llama-3.1-8B-bnb-4bit"
CHECKPOINT_PATH = Path("wpkklhc6")

warnings.filterwarnings("ignore", category=UserWarning)


class ImageAdapter(nn.Module):
    """Two-layer MLP that projects CLIP vision features into the LLM's embedding space."""

    def __init__(self, input_features: int, output_features: int):
        super().__init__()
        self.linear1 = nn.Linear(input_features, output_features)
        self.activation = nn.GELU()
        self.linear2 = nn.Linear(output_features, output_features)

    def forward(self, vision_outputs: torch.Tensor):
        x = self.linear1(vision_outputs)
        x = self.activation(x)
        x = self.linear2(x)
        return x


# Load CLIP
print("Loading CLIP 📎")
clip_processor = AutoProcessor.from_pretrained(CLIP_PATH)
clip_model = AutoModel.from_pretrained(CLIP_PATH)
clip_model = clip_model.vision_model
clip_model.eval()
clip_model.requires_grad_(False)
clip_model.to("cuda")

# Tokenizer
print("Loading tokenizer 🪙")
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, use_fast=False)
assert isinstance(tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)), f"Tokenizer is of type {type(tokenizer)}"

# LLM
print("Loading LLM 🤖")
logging.getLogger("transformers").setLevel(logging.ERROR)
text_model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, device_map="auto", torch_dtype=torch.bfloat16)
text_model.eval()

# Image Adapter
print("Loading image adapter 🖼️")
image_adapter = ImageAdapter(clip_model.config.hidden_size, text_model.config.hidden_size)
image_adapter.load_state_dict(torch.load(CHECKPOINT_PATH / "image_adapter.pt", map_location="cpu", weights_only=True))
image_adapter.eval()
image_adapter.to("cuda")


@torch.no_grad()
def stream_chat(input_images: List[Image.Image], batch_size=4, pbar=None):
    torch.cuda.empty_cache()
    all_captions = []

    if not isinstance(input_images, list):
        input_images = [input_images]

    for i in range(0, len(input_images), batch_size):
        batch = input_images[i:i + batch_size]

        # Preprocess image batch
        try:
            images = clip_processor(images=batch, return_tensors='pt', padding=True).pixel_values
        except ValueError as e:
            print(f"Error processing image batch: {e}")
            print("Skipping this batch and continuing...")
            continue
        images = images.to('cuda')

        # Embed image batch
        with torch.amp.autocast_mode.autocast('cuda', enabled=True):
            vision_outputs = clip_model(pixel_values=images, output_hidden_states=True)
            image_features = vision_outputs.hidden_states[-2]
            embedded_images = image_adapter(image_features)
            embedded_images = embedded_images.to(dtype=torch.bfloat16)

        # Embed prompt
        prompt = tokenizer.encode(VLM_PROMPT, return_tensors='pt')
        prompt_embeds = text_model.model.embed_tokens(prompt.to('cuda')).to(dtype=torch.bfloat16)
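        # The text model never sees real image tokens: the adapter output is spliced
        # straight into the embedding sequence as [BOS][image embeds][prompt embeds],
        # with zeroed input_ids standing in for the image positions below.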
        embedded_bos = text_model.model.embed_tokens(torch.tensor([[tokenizer.bos_token_id]], device=text_model.device, dtype=torch.int64)).to(dtype=torch.bfloat16)

        # Construct prompts
        inputs_embeds = torch.cat([
            embedded_bos.expand(embedded_images.shape[0], -1, -1),
            embedded_images,
            prompt_embeds.expand(embedded_images.shape[0], -1, -1),
        ], dim=1).to(dtype=torch.bfloat16)

        input_ids = torch.cat([
            torch.tensor([[tokenizer.bos_token_id]], dtype=torch.long).expand(embedded_images.shape[0], -1),
            torch.zeros((embedded_images.shape[0], embedded_images.shape[1]), dtype=torch.long),
            prompt.expand(embedded_images.shape[0], -1),
        ], dim=1).to('cuda')
        attention_mask = torch.ones_like(input_ids)

        generate_ids = text_model.generate(
            input_ids=input_ids,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            max_new_tokens=300,
            do_sample=True,
            top_k=10,
            temperature=0.5,
        )

        if pbar:
            pbar.update(len(batch))

        # Trim off the prompt
        generate_ids = generate_ids[:, input_ids.shape[1]:]

        for ids in generate_ids:
            if ids[-1] == tokenizer.eos_token_id:
                ids = ids[:-1]
            caption = tokenizer.decode(ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
            # Remove any remaining special tokens
            caption = caption.replace('<|end_of_text|>', '').replace('<|finetune_right_pad_id|>', '').strip()
            all_captions.append(caption)

    return all_captions


def preprocess_image(img):
    # Note: currently unused; the rest of the pipeline converts images to RGB
    return img.convert('RGBA')


def process_image(image_path, output_path, pbar=None):
    try:
        with Image.open(image_path) as img:
            # Convert image to RGB, pass it as a list to stream_chat,
            # and take the first (and only) caption
            img = img.convert('RGB')
            caption = stream_chat([img], pbar=pbar)[0]
        with open(output_path, 'w', encoding='utf-8') as f:
            f.write(caption)
    except Exception as e:
        print(f"Error processing {image_path}: {e}")
        if pbar:
            pbar.update(1)


def process_directory(input_dir, output_dir, batch_size):
    input_path = Path(input_dir)
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    image_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.webp')
    image_files = [f for f in input_path.iterdir() if f.suffix.lower() in image_extensions]

    # Check which images still need processing; skipping existing
    # captions makes reruns incremental
    images_to_process = []
    for file in image_files:
        output_file = output_path / (file.stem + '.txt')
        if not output_file.exists():
            images_to_process.append(file)
        else:
            print(f"Skipping {file.name} - Caption already exists")

    # Process images in batches
    with tqdm(total=len(images_to_process), desc="Processing images", unit="image") as pbar:
        for i in range(0, len(images_to_process), batch_size):
            batch_files = images_to_process[i:i + batch_size]
            batch_images = []
            opened_files = []
            for f in batch_files:
                try:
                    img = Image.open(f).convert('RGB')
                    batch_images.append(img)
                    opened_files.append(f)
                except Exception as e:
                    print(f"Error opening {f}: {e}")
                    continue

            if batch_images:
                captions = stream_chat(batch_images, batch_size, pbar)
                # Pair captions with the files that actually opened, so one bad
                # image does not shift captions onto the wrong files
                for file, caption in zip(opened_files, captions):
                    output_file = output_path / (file.stem + '.txt')
                    with open(output_file, 'w', encoding='utf-8') as f:
                        f.write(caption)

                # Close the image files
                for img in batch_images:
                    img.close()
parser.add_argument("--bs", type=int, default=4, help="Batch size (default: 4)") return parser.parse_args() def is_image_file(file_path): image_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.webp') return Path(file_path).suffix.lower() in image_extensions # Main execution if __name__ == "__main__": args = parse_arguments() input_paths = [Path(input_path) for input_path in args.input] batch_size = args.bs for input_path in input_paths: if input_path.is_file() and is_image_file(input_path): # Single file processing output_path = input_path.with_suffix('.txt') print(f"Processing single image 🎞️: {input_path.name}") with tqdm(total=1, desc="Processing image", unit="image") as pbar: process_image(input_path, output_path, pbar) print(f"Output saved to {output_path}") elif input_path.is_dir(): # Directory processing output_path = Path(args.output) if args.output else input_path print(f"Processing directory 📁: {input_path}") print(f"Output directory 📦: {output_path}") print(f"Batch size 🗄️: {batch_size}") process_directory(input_path, output_path, batch_size) else: print(f"Invalid input: {input_path}") print("Skipping...") if not input_paths: print("Usage:") print("For single image: python app.py [image_file] [--bs batch_size]") print("For directory (same input/output): python app.py [directory] [--bs batch_size]") print("For directory (separate input/output): python app.py [directory] --output [output_directory] [--bs batch_size]") print("For multiple directories: python app.py [directory1] [directory2] ... [--output output_directory] [--bs batch_size]") sys.exit(1)