batch input #8
opened by bghira
Will it support batched inputs? Running one sample at a time is very slow, and it only uses 37 GB of VRAM, which is not all of the 80 GB we have here.
Yes, try something like this:
```python
import torch
import requests
from PIL import Image
from transformers import AutoModelForCausalLM, LlamaTokenizer

tokenizer = LlamaTokenizer.from_pretrained('lmsys/vicuna-7b-v1.5')
model = AutoModelForCausalLM.from_pretrained(
    'THUDM/cogvlm-chat-hf',
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    trust_remote_code=True
).to('cuda').eval()

# Sample 1: a follow-up question with one turn of history.
input_sample1 = model.build_conversation_input_ids(
    tokenizer,
    images=[Image.open(requests.get('https://github.com/THUDM/CogVLM/blob/main/openai_demo/demo.jpg?raw=true', stream=True).raw).convert('RGB')],
    query='Do you think this is a spring or winter photo?',  # Q2
    history=[
        (
            "What's in this image?",  # Q1
            'The image displays a wooden boardwalk extending through a vibrant green grassy wetland.'  # A1
        )
    ],
)

# Sample 2: a fresh conversation with no history.
input_sample2 = model.build_conversation_input_ids(
    tokenizer,
    images=[Image.open(requests.get('https://github.com/THUDM/CogVLM/blob/main/examples/1.png?raw=true', stream=True).raw).convert('RGB')],
    query='Describe this image',  # Q1
    history=[],
)

def recur_move_to(item, tgt, criterion_func):
    """Recursively move every element matching criterion_func to tgt (a device or dtype)."""
    if criterion_func(item):
        return item.to(tgt)
    elif isinstance(item, list):
        return [recur_move_to(v, tgt, criterion_func) for v in item]
    elif isinstance(item, tuple):
        return tuple(recur_move_to(v, tgt, criterion_func) for v in item)
    elif isinstance(item, dict):
        return {k: recur_move_to(v, tgt, criterion_func) for k, v in item.items()}
    else:
        return item

def collate_fn(features, tokenizer) -> dict:
    # The images stay a plain list; only the token sequences get padded.
    images = [feature.pop('images') for feature in features]
    # Left padding, so every prompt ends at the same position and
    # generation continues directly from the last prompt token.
    tokenizer.padding_side = 'left'
    padded_features = tokenizer.pad(features)
    return {**padded_features, 'images': images}

input_batch = collate_fn([input_sample1, input_sample2], tokenizer)
input_batch = recur_move_to(input_batch, 'cuda', lambda x: isinstance(x, torch.Tensor))
input_batch = recur_move_to(input_batch, torch.bfloat16, lambda x: isinstance(x, torch.Tensor) and torch.is_floating_point(x))

gen_kwargs = {'max_length': 2048, 'do_sample': False}
with torch.no_grad():
    outputs = model.generate(**input_batch, **gen_kwargs)
    # Strip the prompt tokens so only the newly generated text remains.
    outputs = outputs[:, input_batch['input_ids'].shape[1]:]
    print(tokenizer.batch_decode(outputs))
```
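For a larger set of images, the same pieces compose into a chunked loop. This is a minimal sketch, not part of the original reply: the `caption_in_batches` helper, the `pairs` list, and `batch_size` are hypothetical names, and it assumes `model`, `tokenizer`, `collate_fn`, and `recur_move_to` are defined as in the snippet above.

```python
# Hypothetical helper: caption many (PIL image, query string) pairs in fixed-size batches.
# Assumes model, tokenizer, collate_fn, and recur_move_to from the snippet above.
def caption_in_batches(pairs, batch_size=4):
    results = []
    for i in range(0, len(pairs), batch_size):
        chunk = pairs[i:i + batch_size]
        samples = [
            model.build_conversation_input_ids(tokenizer, images=[img], query=query, history=[])
            for img, query in chunk
        ]
        batch = collate_fn(samples, tokenizer)
        batch = recur_move_to(batch, 'cuda', lambda x: isinstance(x, torch.Tensor))
        batch = recur_move_to(batch, torch.bfloat16,
                              lambda x: isinstance(x, torch.Tensor) and torch.is_floating_point(x))
        with torch.no_grad():
            out = model.generate(**batch, max_length=2048, do_sample=False)
            out = out[:, batch['input_ids'].shape[1]:]  # drop the prompt tokens
        results.extend(tokenizer.batch_decode(out, skip_special_tokens=True))
    return results
```

Since left padding pads every sequence to the longest prompt in the batch, grouping samples with similar prompt lengths wastes less compute on padding tokens.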
chenkq changed discussion status to closed