diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..03ff76df5665b3fa05b3be5a1699b1e0dd298d41 --- /dev/null +++ b/.gitignore @@ -0,0 +1,10 @@ +__pycache__ +*.pyc +*.egg-info +dist + +output +output_dir +*.pth +*.log +weights \ No newline at end of file diff --git a/README.md b/README.md index e0405e1925d27b053e6581abc975e9c885d3335e..2c5b4b0ef87dfe708a2d0290369750f47f2c50be 100644 --- a/README.md +++ b/README.md @@ -1,15 +1,100 @@ --- -title: LLaMA Adapter V2 +title: OneLLM emoji: 🚀 colorFrom: red colorTo: indigo sdk: gradio -sdk_version: 3.23.0 +sdk_version: 4.7.1 app_file: app.py pinned: false --- -### LLaMA-Adapter -The official demo for LLaMA-Adapter V2. -Please refer to our [arXiv paper](https://arxiv.org/abs/2303.16199) and [github](https://github.com/ZrrSkywalker/LLaMA-Adapter) for more details. +# OneLLM: One Framework to Align All Modalities with Language +[[Project Page](https://onellm.csuhan.com)] [[Paper](#)] [[Web Demo](https://huggingface.co/spaces/csuhan/OneLLM)] + +Authors: [Jiaming Han](), [Kaixiong Gong](), [Yiyuan Zhang](), [Jiaqi Wang](), [Kaipeng Zhang](), [Dahua Lin](), [Yu Qiao](), [Peng Gao](), [Xiangyu Yue](). + +## News + +- **2023.12.01** Release model weights and inference code. + +## Contents + +- [Install](#install) +- [Models](#models) +- [Demo](#demo) + + + + + +### TODO + +- [ ] Data +- [ ] Evaluation +- [ ] Training + +### Install + +1. Clone the repo into a local folder. + +```bash +git clone https://github.com/csuhan/OneLLM + +cd OneLLM +``` + +2. Install packages. + +```bash +conda create -n onellm python=3.9 -y +conda activate onellm + +pip install -r requirements.txt + +# install pointnet +cd lib/pointnet2 +python setup.py install +``` + +3. Install Apex. (Optional) + +```bash +git clone https://github.com/NVIDIA/apex +cd apex +pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" ./ +``` + +### Models + +We provide a preview model at: [csuhan/OneLLM-7B](https://huggingface.co/csuhan/OneLLM-7B). + +### Demo + +**Huggingface Demo:** [csuhan/OneLLM](https://huggingface.co/spaces/csuhan/OneLLM). + +**Local Demo:** Assume you have downloaded the weights to ${WEIGHTS_DIR}. Then run the following command to start a gradio demo locally. 
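If the checkpoint has not been downloaded yet, it can also be fetched programmatically from the Hugging Face Hub, mirroring what `app.py` in this repo does. This is only a sketch: the repo id and filename are taken from `app.py`, while the `local_dir` below is an arbitrary example standing in for `${WEIGHTS_DIR}`.

```python
from huggingface_hub import hf_hub_download

# Pull the OneLLM-7B checkpoint used by the demo; local_dir is an example
# location and plays the role of ${WEIGHTS_DIR} in the command below.
ckpt_path = hf_hub_download(
    repo_id="csuhan/OneLLM-7B",
    filename="consolidated.00-of-01.pth",
    local_dir="weights/7B_2048",
)
print(ckpt_path)
```

With the checkpoint in place, point `--pretrained_path` in the command below at that file.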
+ +```bash +python demos/multi_turn_mm.py --gpu_ids 0 --tokenizer_path config/llama2/tokenizer.model --llama_config config/llama2/7B.json --pretrained_path ${WEIGHTS_DIR}/consolidated.00-of-01.pth +``` + + + + + +## Citation + +``` +@article{han2023onellm, + title={OneLLM: One Framework to Align All Modalities with Language}, + author={Han, Jiaming and Gong, Kaixiong and Zhang, Yiyuan and Wang, Jiaqi and Zhang, Kaipeng and Lin, Dahua and Qiao, Yu and Gao, Peng and Yue, Xiangyu}, + journal={arXiv preprint arXiv:xxxx}, + year={2023} +} +``` + +## Acknowledgement + +[LLaMA](https://github.com/facebookresearch/llama), [LLaMA-Adapter](https://github.com/OpenGVLab/LLaMA-Adapter), [LLaMA2-Accessory](https://github.com/Alpha-VLLM/LLaMA2-Accessory), [Meta-Transformer](https://github.com/invictus717/MetaTransformer), [ChatBridge](https://github.com/joez17/ChatBridge) diff --git a/app.py b/app.py index 26c88289cdf2fe061461952331735abe6fa46172..a180cda697755716b352d8fa6456204db8aac801 100644 --- a/app.py +++ b/app.py @@ -1,277 +1,272 @@ -import json -import os -import glob import sys -import time -from pathlib import Path -from typing import Tuple +import os + +import argparse +import multiprocessing as mp +import numpy as np +from typing import List, Optional -from huggingface_hub import hf_hub_download -from PIL import Image -import gradio as gr import torch -from fairscale.nn.model_parallel.initialize import initialize_model_parallel - -from llama import LLaMA, ModelArgs, Tokenizer, Transformer, VisionModel - -os.environ['CUDA_LAUNCH_BLOCKING'] = '1' - -PROMPT_DICT = { - "prompt_input": ( - "Below is an instruction that describes a task, paired with an input that provides further context. " - "Write a response that appropriately completes the request.\n\n" - "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:" - ), - "prompt_no_input": ( - "Below is an instruction that describes a task. 
" - "Write a response that appropriately completes the request.\n\n" - "### Instruction:\n{instruction}\n\n### Response:" - ), -} - - -def setup_model_parallel() -> Tuple[int, int]: - os.environ['RANK'] = '0' - os.environ['WORLD_SIZE'] = '1' - os.environ['MP'] = '1' - os.environ['MASTER_ADDR'] = '127.0.0.1' - os.environ['MASTER_PORT'] = '2223' - local_rank = int(os.environ.get("LOCAL_RANK", -1)) - world_size = int(os.environ.get("WORLD_SIZE", -1)) - - torch.distributed.init_process_group("nccl") - initialize_model_parallel(world_size) - torch.cuda.set_device(local_rank) - - # seed must be the same in all processes - torch.manual_seed(1) - return local_rank, world_size - - -def load( - ckpt0_path: str, - ckpt1_path: str, - param_path: str, - tokenizer_path: str, - instruct_adapter_path: str, - caption_adapter_path: str, - local_rank: int, - world_size: int, - max_seq_len: int, - max_batch_size: int, -) -> LLaMA: - start_time = time.time() - print("Loading") - instruct_adapter_checkpoint = torch.load( - instruct_adapter_path, map_location="cpu") - caption_adapter_checkpoint = torch.load( - caption_adapter_path, map_location="cpu") - with open(param_path, "r") as f: - params = json.loads(f.read()) - - model_args: ModelArgs = ModelArgs( - max_seq_len=max_seq_len, max_batch_size=max_batch_size, **params - ) - model_args.adapter_layer = int( - instruct_adapter_checkpoint['adapter_query.weight'].shape[0] / model_args.adapter_len) - model_args.cap_adapter_layer = int( - caption_adapter_checkpoint['cap_adapter_query.weight'].shape[0] / model_args.cap_adapter_len) - - tokenizer = Tokenizer(model_path=tokenizer_path) - model_args.vocab_size = tokenizer.n_words - torch.set_default_tensor_type(torch.cuda.HalfTensor) - model = Transformer(model_args) - - # To reduce memory usuage - ckpt0 = torch.load(ckpt0_path, map_location='cuda') - model.load_state_dict(ckpt0, strict=False) - del ckpt0 - torch.cuda.empty_cache() - - ckpt1 = torch.load(ckpt1_path, map_location='cuda') - model.load_state_dict(ckpt1, strict=False) - del ckpt1 - torch.cuda.empty_cache() - - vision_model = VisionModel(model_args) - - torch.set_default_tensor_type(torch.FloatTensor) - model.load_state_dict(instruct_adapter_checkpoint, strict=False) - model.load_state_dict(caption_adapter_checkpoint, strict=False) - vision_model.load_state_dict(caption_adapter_checkpoint, strict=False) - - generator = LLaMA(model, tokenizer, vision_model) - print(f"Loaded in {time.time() - start_time:.2f} seconds") - return generator - - -def instruct_generate( - instruct: str, - input: str = 'none', - max_gen_len=512, - temperature: float = 0.1, - top_p: float = 0.75, -): - if input == 'none': - prompt = PROMPT_DICT['prompt_no_input'].format_map( - {'instruction': instruct, 'input': ''}) - else: - prompt = PROMPT_DICT['prompt_input'].format_map( - {'instruction': instruct, 'input': input}) - - results = generator.generate( - [prompt], max_gen_len=max_gen_len, temperature=temperature, top_p=top_p - ) - result = results[0].strip() - print(result) - return result - - -def caption_generate( - img: str, - max_gen_len=512, - temperature: float = 0.1, - top_p: float = 0.75, -): - imgs = [Image.open(img).convert('RGB')] - prompts = ["Generate caption of this image :",] * len(imgs) - - results = generator.generate( - prompts, imgs=imgs, max_gen_len=max_gen_len, temperature=temperature, top_p=top_p - ) - result = results[0].strip() - print(result) - return result - - -def download_llama_adapter(instruct_adapter_path, caption_adapter_path): - if not 
os.path.exists(instruct_adapter_path): - os.system( - f"wget -q -O {instruct_adapter_path} https://github.com/ZrrSkywalker/LLaMA-Adapter/releases/download/v.1.0.0/llama_adapter_len10_layer30_release.pth") - - if not os.path.exists(caption_adapter_path): - os.system( - f"wget -q -O {caption_adapter_path} https://github.com/ZrrSkywalker/LLaMA-Adapter/releases/download/v.1.0.0/llama_adapter_len10_layer30_caption_vit_l.pth") - - -# ckpt_path = "/data1/llma/7B/consolidated.00.pth" -# param_path = "/data1/llma/7B/params.json" -# tokenizer_path = "/data1/llma/tokenizer.model" -ckpt0_path = hf_hub_download( - repo_id="csuhan/llama_storage", filename="consolidated.00_part0.pth") -ckpt1_path = hf_hub_download( - repo_id="csuhan/llama_storage", filename="consolidated.00_part1.pth") -param_path = hf_hub_download( - repo_id="nyanko7/LLaMA-7B", filename="params.json") -tokenizer_path = hf_hub_download( - repo_id="nyanko7/LLaMA-7B", filename="tokenizer.model") -instruct_adapter_path = "llama_adapter_len10_layer30_release.pth" -caption_adapter_path = "llama_adapter_len10_layer30_caption_vit_l.pth" -max_seq_len = 512 -max_batch_size = 1 - -# download models -# download_llama_adapter(instruct_adapter_path, caption_adapter_path) - -local_rank, world_size = setup_model_parallel() -if local_rank > 0: - sys.stdout = open(os.devnull, "w") - -generator = load( - ckpt0_path, ckpt1_path, param_path, tokenizer_path, instruct_adapter_path, caption_adapter_path, local_rank, world_size, max_seq_len, max_batch_size -) - - -def create_instruct_demo(): - with gr.Blocks() as instruct_demo: - with gr.Row(): - with gr.Column(): - instruction = gr.Textbox(lines=2, label="Instruction") - input = gr.Textbox( - lines=2, label="Context input", placeholder='none') - max_len = gr.Slider(minimum=1, maximum=512, - value=128, label="Max length") - with gr.Accordion(label='Advanced options', open=False): - temp = gr.Slider(minimum=0, maximum=1, - value=0.1, label="Temperature") - top_p = gr.Slider(minimum=0, maximum=1, - value=0.75, label="Top p") - - run_botton = gr.Button("Run") - - with gr.Column(): - outputs = gr.Textbox(lines=10, label="Output") - - inputs = [instruction, input, max_len, temp, top_p] - - examples = [ - "Tell me about alpacas.", - "Write a Python program that prints the first 10 Fibonacci numbers.", - "Write a conversation between the sun and pluto.", - "Write a theory to explain why cat never existed", - ] - examples = [ - [x, "none", 128, 0.1, 0.75] - for x in examples] - - gr.Examples( - examples=examples, - inputs=inputs, - outputs=outputs, - fn=instruct_generate, - cache_examples=os.getenv('SYSTEM') == 'spaces' - ) - run_botton.click(fn=instruct_generate, inputs=inputs, outputs=outputs) - return instruct_demo +import torch.distributed as dist +from fairscale.nn.model_parallel import initialize as fs_init -def create_caption_demo(): - with gr.Blocks() as instruct_demo: - with gr.Row(): - with gr.Column(): - img = gr.Image(label='Input', type='filepath') - max_len = gr.Slider(minimum=1, maximum=512, - value=64, label="Max length") - with gr.Accordion(label='Advanced options', open=False): - temp = gr.Slider(minimum=0, maximum=1, - value=0.1, label="Temperature") - top_p = gr.Slider(minimum=0, maximum=1, - value=0.75, label="Top p") - - run_botton = gr.Button("Run") - - with gr.Column(): - outputs = gr.Textbox(lines=10, label="Output") - - inputs = [img, max_len, temp, top_p] - - examples = glob.glob("caption_demo/*.jpg") - examples = [ - [x, 64, 0.1, 0.75] - for x in examples] - - gr.Examples( - 
examples=examples, - inputs=inputs, - outputs=outputs, - fn=caption_generate, - cache_examples=os.getenv('SYSTEM') == 'spaces' - ) - run_botton.click(fn=caption_generate, inputs=inputs, outputs=outputs) - return instruct_demo +import gradio as gr +from util.misc import setup_for_distributed +from util.misc import default_tensor_type +from model.meta import MetaModel +from data.conversation_lib import conv_templates, SeparatorStyle +from PIL import Image +import torchvision.transforms as transforms +from data.fintune_dataset import make_audio_features +from data import video_utils +from dataclasses import dataclass +from huggingface_hub import hf_hub_download +T_random_resized_crop = transforms.Compose([ + transforms.RandomResizedCrop(size=(224, 224), scale=(0.9, 1.0), ratio=(0.75, 1.3333), interpolation=3, + antialias=None), # 3 is bicubic + transforms.ToTensor(), + transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711])]) + + +def load_audio(audio_path): + fbank = make_audio_features(audio_path, mel_bins=128) + fbank = fbank.transpose(0, 1)[None] #[1, 128, 1024] + return fbank + +def load_video(video_path): + video_feats = video_utils.load_and_transform_video_data(video_path, video_path, clip_duration=1, clips_per_video=5) + return video_feats[:, :, 0] + + +def model_worker( + rank: int, args: argparse.Namespace, barrier: mp.Barrier, + request_queue: mp.Queue, response_queue: Optional[mp.Queue] = None, +) -> None: + """ + The worker function that manipulates the GPU to run the inference. + Exact n_gpu workers are started, with each one operating on a separate GPU. + + Args: + rank (int): Distributed rank of the worker. + args (argparse.Namespace): All command line arguments. + barrier (multiprocessing.Barrier): A barrier used to delay the start + of Web UI to be after the start of the model. + """ + + world_size = len(args.gpu_ids) + gpu_id = args.gpu_ids[rank] + dist.init_process_group( + backend="nccl", rank=rank, world_size=world_size, + init_method=f"tcp://{args.master_addr}:{args.master_port}", + ) + print(f"| distributed init on worker {rank}/{world_size}. " + f"using gpu: {gpu_id}") + fs_init.initialize_model_parallel(world_size) + torch.cuda.set_device(gpu_id) -description = """ -# LLaMA-Adapter🚀 -The official demo for **LLaMA-Adapter: Efficient Fine-tuning of Language Models with Zero-init Attention**. -Please refer to our [arXiv paper](https://arxiv.org/abs/2303.16199) and [github](https://github.com/ZrrSkywalker/LLaMA-Adapter) for more details. -""" + torch.manual_seed(1) + np.random.seed(1) + + # set the print behavior. 
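# (setup_for_distributed comes from util.misc, which is not part of this diff;
# it is presumably the usual helper that silences print() on non-master ranks,
# so passing rank == 0 keeps a single copy of the startup log.)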
+ setup_for_distributed(rank == 0) + + target_dtype = { + "bf16": torch.bfloat16, + "fp16": torch.float16 + }[args.dtype] + with default_tensor_type(dtype=target_dtype, device="cuda"): + model = MetaModel(args.llama_type, args.llama_config, tokenizer_path=args.tokenizer_path) + print("Loading pretrained weights ...") + checkpoint = torch.load(args.pretrained_path, map_location='cpu') + msg = model.load_state_dict(checkpoint, strict=False) + print("load result:\n", msg) + model.cuda() + model.eval() + print(f"Model = {str(model)}") + + barrier.wait() + + while True: + img_path, audio_path, video_path, chatbot, max_gen_len, temperature, top_p, modality = request_queue.get() + if 'image' in modality and img_path is not None: + image = Image.open(img_path).convert('RGB') + inputs = T_random_resized_crop(image) + elif 'video' in modality and video_path is not None: + inputs = load_video(video_path) + elif 'audio' in modality and audio_path is not None: + inputs = load_audio(audio_path) + else: + inputs = None + + if inputs is not None: + inputs = inputs[None].cuda().to(target_dtype) + + conv = conv_templates["v1"].copy() + for user, bot in chatbot: + conv.append_message(conv.roles[0], user) + conv.append_message(conv.roles[1], bot) + + with torch.cuda.amp.autocast(dtype=target_dtype): + print(conv.get_prompt()) + for stream_response in model.stream_generate( + conv.get_prompt(), inputs, + max_gen_len=max_gen_len, temperature=temperature, top_p=top_p, + modal = modality + ): + conv_sep = ( + conv.sep + if conv.sep_style == SeparatorStyle.SINGLE + else conv.sep2 + ) + end_pos = stream_response["text"].find(conv_sep) + if end_pos != -1: + stream_response["text"] = ( + stream_response['text'][:end_pos].rstrip() + "\n" + ) + stream_response["end_of_content"] = True + + # keep a few characters if not end_of_content to avoid sending + # part of conv_sep before all of it is generated. + if not stream_response["end_of_content"]: + if len(stream_response["text"]) < len(conv_sep): + continue + stream_response["text"] = ( + stream_response["text"][:-len(conv_sep)] + ) + + if response_queue is not None: + response_queue.put(stream_response) + + if stream_response["end_of_content"]: + break + + +def gradio_worker( + request_queues: List[mp.Queue], response_queue: mp.Queue, + args: argparse.Namespace, barrier: mp.Barrier, +) -> None: + """ + The gradio worker is responsible for displaying the WebUI and relay the + requests to model workers. It should be launched only once. + + Args: + request_queues (List[mp.Queue]): A list of request queues (one for + each model worker). + args (argparse.Namespace): All command line arguments. + barrier (multiprocessing.Barrier): A barrier used to delay the start + of Web UI to be after the start of the model. 
+ """ + + def show_user_input(msg, chatbot): + return "", chatbot + [[msg, None]] + + def stream_model_output(img_path, audio_path, video_path, chatbot, max_gen_len, gen_t, top_p, modality): + for queue in request_queues: + queue.put((img_path, audio_path, video_path, chatbot, max_gen_len, gen_t, top_p, modality)) + while True: + content_piece = response_queue.get() + chatbot[-1][1] = content_piece["text"] + yield chatbot + if content_piece["end_of_content"]: + break + + def undo(chatbot): + if len(chatbot) > 0: + chatbot = chatbot[:-1] + return chatbot + + def clear(): + chatbot = [] + msg = "" + return chatbot, msg + + CSS =""" + .contain { display: flex; flex-direction: column; } + #component-0 { height: 100%; } + #chatbot { flex-grow: 1; overflow: auto;} + """ + with gr.Blocks(css=CSS) as demo: + gr.Markdown("## OneLLM: One Framework to Align All Modalities with Language") + with gr.Row(equal_height=True): + with gr.Column(scale=1): + img_path = gr.Image(label='Image Input', type='filepath') + video_path = gr.Video(label='Video Input') + audio_path = gr.Audio(label='Audio Input', type='filepath', sources=['upload']) + modality = gr.Radio(choices=['image', 'audio', 'video'], value='image', interactive=True, label='Input Modalities') + + with gr.Column(scale=2): + chatbot = gr.Chatbot(elem_id="chatbot") + msg = gr.Textbox() -with gr.Blocks(css='style.css') as demo: - gr.Markdown(description) - with gr.TabItem("Instruction-Following"): - create_instruct_demo() - with gr.TabItem("Image Captioning"): - create_caption_demo() + with gr.Row(): + submit_button = gr.Button("Submit", variant="primary") + undo_button = gr.Button("Undo") + clear_button = gr.ClearButton([chatbot, msg, img_path, audio_path, video_path, modality]) + with gr.Row(): + max_gen_len = gr.Slider( + minimum=1, maximum=args.model_max_seq_len // 2, + value=args.model_max_seq_len // 2, interactive=True, + label="Single-turn max response length", + ) + gen_t = gr.Slider( + minimum=0, maximum=1, value=0.1, interactive=True, + label="Temperature", + ) + top_p = gr.Slider( + minimum=0, maximum=1, value=0.75, interactive=True, + label="Top-p", + ) + msg.submit( + show_user_input, [msg, chatbot], [msg, chatbot], + ).then( + stream_model_output, [img_path, audio_path, video_path, chatbot, max_gen_len, gen_t, top_p, modality], chatbot, + ) + submit_button.click( + show_user_input, [msg, chatbot], [msg, chatbot], + ).then( + stream_model_output, [img_path, audio_path, video_path, chatbot, max_gen_len, gen_t, top_p, modality], chatbot, + ) + undo_button.click(undo, chatbot, chatbot) + # img_path.change(clear, [], [chatbot, msg]) + barrier.wait() + demo.queue(api_open=True).launch(share=True, max_threads=1) + + +@dataclass +class DemoConfig: + gpu_ids = [0] + tokenizer_path = "config/llama2/tokenizer.model" + llama_type = "onellm" + llama_config = "config/llama2/7B.json" + model_max_seq_len = 2048 + # pretrained_path = "weights/7B_2048/consolidated.00-of-01.pth" + pretrained_path = hf_hub_download(repo_id="csuhan/OneLLM-7B", filename="consolidated.00-of-01.pth") + master_port = 23861 + master_addr = "127.0.0.1" + dtype = "fp16" + +if __name__ == "__main__": + args = DemoConfig() + # using the default "fork" method messes up some imported libs (e.g., + # pandas) + mp.set_start_method("spawn") + + # setup the queues and start the model workers + request_queues = [] + response_queue = mp.Queue() + worker_processes = [] + barrier = mp.Barrier(len(args.gpu_ids) + 1) + for rank, gpu_id in enumerate(args.gpu_ids): + request_queue = mp.Queue() + 
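# Queue topology, as set up here and consumed by gradio_worker: each model
# worker owns one request queue, and the web UI puts every request on all of
# them so that all model-parallel ranks receive the same prompt; only rank 0
# is handed the shared response queue and streams tokens back to the UI.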
rank_response_queue = response_queue if rank == 0 else None + process = mp.Process( + target=model_worker, + args=(rank, args, barrier, request_queue, rank_response_queue), + ) + process.start() + worker_processes.append(process) + request_queues.append(request_queue) -demo.queue(api_open=True, concurrency_count=1).launch() + gradio_worker(request_queues, response_queue, args, barrier) diff --git a/config/llama2/7B.json b/config/llama2/7B.json new file mode 100644 index 0000000000000000000000000000000000000000..6523f76675b50e9cf3a57d1fb135189abcffe1c7 --- /dev/null +++ b/config/llama2/7B.json @@ -0,0 +1 @@ +{"dim": 4096, "multiple_of": 256, "n_heads": 32, "n_layers": 32, "norm_eps": 1e-05, "vocab_size": -1} diff --git a/config/llama2/tokenizer.model b/config/llama2/tokenizer.model new file mode 100644 index 0000000000000000000000000000000000000000..6c00c742ce03c627d6cd5b795984876fa49fa899 --- /dev/null +++ b/config/llama2/tokenizer.model @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9e556afd44213b6bd1be2b850ebbbd98f5481437a8021afaf58ee7fb1818d347 +size 499723 diff --git a/data/__pycache__/conversation_lib.cpython-310.pyc b/data/__pycache__/conversation_lib.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7104daf11059185efd723e40160c7debf191003b Binary files /dev/null and b/data/__pycache__/conversation_lib.cpython-310.pyc differ diff --git a/data/__pycache__/conversation_lib.cpython-39.pyc b/data/__pycache__/conversation_lib.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fdca3a64c8523a2b4439b3e0c894c64b2020f486 Binary files /dev/null and b/data/__pycache__/conversation_lib.cpython-39.pyc differ diff --git a/data/__pycache__/fintune_dataset.cpython-310.pyc b/data/__pycache__/fintune_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2e4e08c37f7d6e25b19727d2847d34fc1fdd0c8e Binary files /dev/null and b/data/__pycache__/fintune_dataset.cpython-310.pyc differ diff --git a/data/__pycache__/fintune_dataset.cpython-39.pyc b/data/__pycache__/fintune_dataset.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..45989953b86dec0a2084a59127bc1028067b8640 Binary files /dev/null and b/data/__pycache__/fintune_dataset.cpython-39.pyc differ diff --git a/data/__pycache__/imu_utils.cpython-310.pyc b/data/__pycache__/imu_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cae8e9e22e039ecc388eb3193589fa83b5c3847b Binary files /dev/null and b/data/__pycache__/imu_utils.cpython-310.pyc differ diff --git a/data/__pycache__/imu_utils.cpython-39.pyc b/data/__pycache__/imu_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..46ddf01f39b937cad76df5c6201b4f62b441694d Binary files /dev/null and b/data/__pycache__/imu_utils.cpython-39.pyc differ diff --git a/data/__pycache__/video_utils.cpython-310.pyc b/data/__pycache__/video_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4c007e0ea4f4c05fe867e0bc31077d4c6bc0fe79 Binary files /dev/null and b/data/__pycache__/video_utils.cpython-310.pyc differ diff --git a/data/__pycache__/video_utils.cpython-39.pyc b/data/__pycache__/video_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..81b5c7d05231e85a9f4552385921740940514e39 Binary files /dev/null and b/data/__pycache__/video_utils.cpython-39.pyc differ diff --git a/data/conversation_lib.py 
b/data/conversation_lib.py new file mode 100644 index 0000000000000000000000000000000000000000..783fe0eb8f9dd425ec6c285e820f755d2e955a3b --- /dev/null +++ b/data/conversation_lib.py @@ -0,0 +1,369 @@ +import dataclasses +from enum import auto, Enum +from typing import List, Tuple + + +class SeparatorStyle(Enum): + """Different separator style.""" + SINGLE = auto() + TWO = auto() + MPT = auto() + + +@dataclasses.dataclass +class Conversation: + """A class that keeps all conversation history.""" + system: str + roles: List[str] + messages: List[List[str]] + offset: int + sep_style: SeparatorStyle = SeparatorStyle.SINGLE + sep: str = "###" + sep2: str = None + version: str = "Unknown" + + skip_next: bool = False + + def get_prompt(self): + if self.sep_style == SeparatorStyle.SINGLE: + ret = self.system + '\n\n' + self.sep + for role, message in self.messages: + if message: + if type(message) is tuple: + message, _, _ = message + ret += role + ": " + message + '\n' + self.sep + else: + ret += role + ":" + return ret + elif self.sep_style == SeparatorStyle.TWO: + seps = [self.sep, self.sep2] + ret = self.system + seps[0] + for i, (role, message) in enumerate(self.messages): + if message: + if type(message) is tuple: + message, _, _ = message + ret += role + ": " + message + seps[i % 2] + else: + ret += role + ":" + return ret + if self.sep_style == SeparatorStyle.MPT: + ret = self.system + self.sep + for role, message in self.messages: + if message: + if type(message) is tuple: + message, _, _ = message + ret += role + message + self.sep + else: + ret += role + return ret + else: + raise ValueError(f"Invalid style: {self.sep_style}") + + def append_message(self, role, message): + self.messages.append([role, message]) + + def get_images(self, return_pil=False): + images = [] + for i, (role, msg) in enumerate(self.messages[self.offset:]): + if i % 2 == 0: + if type(msg) is tuple: + import base64 + from io import BytesIO + from PIL import Image + msg, image, image_process_mode = msg + if image_process_mode == "Pad": + def expand2square(pil_img, background_color=(122, 116, 104)): + width, height = pil_img.size + if width == height: + return pil_img + elif width > height: + result = Image.new(pil_img.mode, (width, width), background_color) + result.paste(pil_img, (0, (width - height) // 2)) + return result + else: + result = Image.new(pil_img.mode, (height, height), background_color) + result.paste(pil_img, ((height - width) // 2, 0)) + return result + + image = expand2square(image) + elif image_process_mode == "Crop": + pass + elif image_process_mode == "Resize": + image = image.resize((224, 224)) + else: + raise ValueError(f"Invalid image_process_mode: {image_process_mode}") + max_hw, min_hw = max(image.size), min(image.size) + aspect_ratio = max_hw / min_hw + max_len, min_len = 800, 400 + shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw)) + longest_edge = int(shortest_edge * aspect_ratio) + W, H = image.size + if H > W: + H, W = longest_edge, shortest_edge + else: + H, W = shortest_edge, longest_edge + image = image.resize((W, H)) + if return_pil: + images.append(image) + else: + buffered = BytesIO() + image.save(buffered, format="JPEG") + img_b64_str = base64.b64encode(buffered.getvalue()).decode() + images.append(img_b64_str) + return images + + def to_gradio_chatbot(self): + ret = [] + for i, (role, msg) in enumerate(self.messages[self.offset:]): + if i % 2 == 0: + if type(msg) is tuple: + import base64 + from io import BytesIO + msg, image, image_process_mode = msg + max_hw, 
min_hw = max(image.size), min(image.size) + aspect_ratio = max_hw / min_hw + max_len, min_len = 800, 400 + shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw)) + longest_edge = int(shortest_edge * aspect_ratio) + W, H = image.size + if H > W: + H, W = longest_edge, shortest_edge + else: + H, W = shortest_edge, longest_edge + image = image.resize((W, H)) + # image = image.resize((224, 224)) + buffered = BytesIO() + image.save(buffered, format="JPEG") + img_b64_str = base64.b64encode(buffered.getvalue()).decode() + img_str = f'user upload image' + msg = msg.replace('', img_str) + ret.append([msg, None]) + else: + ret[-1][-1] = msg + return ret + + def copy(self): + return Conversation( + system=self.system, + roles=self.roles, + messages=[[x, y] for x, y in self.messages], + offset=self.offset, + sep_style=self.sep_style, + sep=self.sep, + sep2=self.sep2) + + def dict(self): + if len(self.get_images()) > 0: + return { + "system": self.system, + "roles": self.roles, + "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages], + "offset": self.offset, + "sep": self.sep, + "sep2": self.sep2, + } + return { + "system": self.system, + "roles": self.roles, + "messages": self.messages, + "offset": self.offset, + "sep": self.sep, + "sep2": self.sep2, + } + + +conv_v1 = Conversation( + system="A chat between a curious human and an artificial intelligence assistant. " + "The assistant gives helpful, detailed, and polite answers to the human's questions.", + roles=("Human", "Assistant"), + messages=( + ("Human", "Give three tips for staying healthy."), + ("Assistant", + "Sure, here are three tips for staying healthy:\n" + "1. Exercise regularly: Regular physical activity can help improve your overall health and wellbeing. " + "It can also help reduce your risk of chronic conditions such as obesity, diabetes, heart disease, " + "and certain cancers. Aim for at least 150 minutes of moderate-intensity aerobic exercise or " + "75 minutes of vigorous-intensity aerobic exercise per week, along with muscle-strengthening " + "activities at least two days per week.\n" + "2. Eat a balanced diet: Eating a balanced diet that is rich in fruits, " + "vegetables, whole grains, lean proteins, and healthy fats can help support " + "your overall health. Try to limit your intake of processed and high-sugar foods, " + "and aim to drink plenty of water throughout the day.\n" + "3. Get enough sleep: Getting enough quality sleep is essential for your physical " + "and mental health. Adults should aim for seven to nine hours of sleep per night. " + "Establish a regular sleep schedule and try to create a relaxing bedtime routine to " + "help improve the quality of your sleep.") + ), + offset=2, + sep_style=SeparatorStyle.SINGLE, + sep="###", +) + +conv_v1_2 = Conversation( + system="A chat between a curious human and an artificial intelligence assistant. " + "The assistant gives helpful, detailed, and polite answers to the human's questions.", + roles=("Human", "Assistant"), + messages=(), + + # ( + # ("Human", "What are the key differences between renewable and non-renewable energy sources?"), + # ("Assistant", + # "Renewable energy sources are those that can be replenished naturally in a relatively " + # "short amount of time, such as solar, wind, hydro, geothermal, and biomass. " + # "Non-renewable energy sources, on the other hand, are finite and will eventually be " + # "depleted, such as coal, oil, and natural gas. 
Here are some key differences between " + # "renewable and non-renewable energy sources:\n" + # "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable " + # "energy sources are finite and will eventually run out.\n" + # "2. Environmental impact: Renewable energy sources have a much lower environmental impact " + # "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, " + # "and other negative effects.\n" + # "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically " + # "have lower operational costs than non-renewable sources.\n" + # "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote " + # "locations than non-renewable sources.\n" + # "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different " + # "situations and needs, while non-renewable sources are more rigid and inflexible.\n" + # "6. Sustainability: Renewable energy sources are more sustainable over the long term, while " + # "non-renewable sources are not, and their depletion can lead to economic and social instability.\n") + # ) + offset = 2, + sep_style = SeparatorStyle.SINGLE, + sep = "###", + ) + +conv_vicuna_v1_1 = Conversation( + system="A chat between a curious user and an artificial intelligence assistant. " + "The assistant gives helpful, detailed, and polite answers to the user's questions.", + roles=("USER", "ASSISTANT"), + version="v1", + messages=(), + offset=0, + sep_style=SeparatorStyle.TWO, + sep=" ", + sep2="", +) + +conv_mpt = Conversation( + system="""<|im_start|>system +- You are a helpful language and vision assistant. +- You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language. +- You should follow the instructions carefully and explain your answers in detail.""", + roles=("<|im_start|>user\n", "<|im_start|>assistant\n"), + version="mpt", + messages=(), + offset=0, + sep_style=SeparatorStyle.MPT, + sep="<|im_end|>", +) + +conv_mpt_text = Conversation( + system="""<|im_start|>system +- You are a helpful assistant chatbot trained by MosaicML. +- You answer questions. +- You are excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user. +- You are more than just an information source, you are also able to write poetry, short stories, and make jokes.""", + roles=("<|im_start|>user\n", "<|im_start|>assistant\n"), + version="mpt", + messages=(), + offset=0, + sep_style=SeparatorStyle.MPT, + sep="<|im_end|>", +) + +conv_bair_v1 = Conversation( + system="BEGINNING OF CONVERSATION:", + roles=("USER", "GPT"), + messages=(), + offset=0, + sep_style=SeparatorStyle.TWO, + sep=" ", + sep2="", +) + +simple_conv = Conversation( + system="A chat between a curious human and an artificial intelligence assistant. " + "The assistant gives helpful, detailed, and polite answers to the human's questions.", + roles=("Human", "Assistant"), + messages=( + ("Human", "Hi!"), + ("Assistant", "Hi there! How can I help you today?") + ), + offset=2, + sep_style=SeparatorStyle.SINGLE, + sep="###", +) + +simple_conv_multimodal = Conversation( + system="You are LLaVA, a large language and vision assistant trained by UW Madison WAIV Lab." + "You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language." 
+ "Follow the instructions carefully and explain your answers in detail.", + roles=("Human", "Assistant"), + messages=( + ("Human", "Hi!"), + ("Assistant", "Hi there! How can I help you today?\n") + ), + offset=2, + sep_style=SeparatorStyle.SINGLE, + sep="###", +) + +simple_conv_mpt_multimodal = Conversation( + system="""<|im_start|>system +- You are LLaVA, a large language and vision assistant trained by UW Madison WAIV Lab. +- You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language. +- You should follow the instructions carefully and explain your answers in detail.""", + roles=("<|im_start|>user\n", "<|im_start|>assistant\n"), + version="mpt", + messages=(), + offset=0, + sep_style=SeparatorStyle.MPT, + sep="<|im_end|>", +) + +simple_conv_legacy = Conversation( + system="You are LLaVA, a large language model trained by UW Madison WAIV Lab." + "You are designed to assist human with a variety of tasks using natural language." + "Follow the instructions carefully.", + roles=("Human", "Assistant"), + messages=( + ("Human", "Hi!\n\n### Response:"), + ("Assistant", "Hi there! How can I help you today?\n") + ), + offset=2, + sep_style=SeparatorStyle.SINGLE, + sep="###", +) + +conv_llava_v1 = Conversation( + system="You are LLaVA, a large language and vision assistant trained by UW Madison WAIV Lab." + "You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language." + "Follow the instructions carefully and explain your answers in detail.", + roles=("USER", "ASSISTANT"), + version="v1", + messages=(), + offset=0, + sep_style=SeparatorStyle.TWO, + sep=" ", + sep2="", +) + +default_conversation = conv_v1_2 +conv_templates = { + "default": conv_v1_2, + "simple": simple_conv, + "simple_legacy": simple_conv_legacy, + "multimodal": simple_conv_multimodal, + "mpt_multimodal": simple_conv_mpt_multimodal, + "llava_v1": conv_llava_v1, + + # fastchat + "v1": conv_v1_2, + "bair_v1": conv_bair_v1, + "vicuna_v1_1": conv_vicuna_v1_1, + "mpt": conv_mpt, + "mpt_text": conv_mpt_text, +} + +if __name__ == "__main__": + print(default_conversation.get_prompt()) diff --git a/data/fintune_dataset.py b/data/fintune_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..6f787139d702fbc46ef5ca8189ade8f89a9c7df0 --- /dev/null +++ b/data/fintune_dataset.py @@ -0,0 +1,449 @@ +import warnings + +import torch +import yaml +from torch.utils.data import Dataset +from PIL import Image +import json +from model.tokenizer import Tokenizer +import os +import torchvision.transforms as transforms +import random +import torchvision.transforms.functional as F +import torchaudio +from . import conversation_lib + +import numpy as np +from . 
import video_utils +from .imu_utils import get_imu_frames + + +IGNORE_INDEX = -100 + +DEFAULT_IMAGE_TOKEN = "" +try: + from torchvision.transforms import InterpolationMode + + BICUBIC = InterpolationMode.BICUBIC +except ImportError: + BICUBIC = Image.BICUBIC + +T_random_resized_crop = transforms.Compose([ + transforms.RandomResizedCrop(size=(224, 224), scale=(0.9, 1.0), ratio=(0.75, 1.3333), interpolation=BICUBIC, + antialias=None), # 3 is bicubic + transforms.ToTensor(), + transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711])]) + + +# image transform +transform_img_train = transforms.Compose([ + transforms.RandomResizedCrop(size=(224, 224), scale=(0.9, 1.0), ratio=( + 0.75, 1.3333), interpolation=3, antialias=None), # 3 is bicubic + transforms.ToTensor(), + transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711])]) + + +class PairRandomResizedCrop(transforms.RandomResizedCrop): + def forward(self, imgs): + i, j, h, w = self.get_params(imgs[0], self.scale, self.ratio) + return [F.resized_crop(img, i, j, h, w, self.size, self.interpolation, antialias=self.antialias) for img in imgs] + + +class PairToTensor(transforms.ToTensor): + def __call__(self, pics): + return [F.to_tensor(pic) for pic in pics] + + +class PairNormalize(transforms.Normalize): + def forward(self, tensors): + return [F.normalize(tensor, self.mean, self.std, self.inplace) for tensor in tensors] + + +transform_pairimg_train = transforms.Compose([ + PairRandomResizedCrop(size=(224, 224), scale=(0.9, 1.0), ratio=( + 0.75, 1.3333), interpolation=3, antialias=None), # 3 is bicubic + PairToTensor(), + PairNormalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711])]) + + +def pc_norm(pc): + """ pc: NxC, return NxC """ + xyz = pc[:, :3] + other_feature = pc[:, 3:] + + centroid = torch.mean(xyz, dim=0) + xyz = xyz - centroid + m = torch.max(torch.sqrt(torch.sum(xyz ** 2, dim=1))) + xyz = xyz / m + + pc = torch.cat((xyz, other_feature), dim=1) + return pc + + +def make_audio_features(wav_name, mel_bins=128, target_length=1024, aug=False): + waveform, sr = torchaudio.load(wav_name) + # assert sr == 16000, 'input audio sampling rate must be 16kHz' + if sr != 16000: + trans = torchaudio.transforms.Resample(sr, 16000) + waveform = trans(waveform) + + waveform = waveform - waveform.mean() + + fbank = torchaudio.compliance.kaldi.fbank( + waveform, htk_compat=True, sample_frequency=16000, use_energy=False, + window_type='hanning', num_mel_bins=mel_bins, dither=0.0, frame_shift=10) + + n_frames = fbank.shape[0] + + p = target_length - n_frames + if p > 0: + m = torch.nn.ZeroPad2d((0, 0, 0, p)) + fbank = m(fbank) + elif p < 0: + fbank = fbank[0:target_length, :] + + if aug: + freqm = torchaudio.transforms.FrequencyMasking(48) + timem = torchaudio.transforms.TimeMasking(192) + fbank = torch.transpose(fbank, 0, 1) + fbank = fbank.unsqueeze(0) + fbank = freqm(fbank) + fbank = timem(fbank) + fbank = fbank.squeeze(0) + fbank = torch.transpose(fbank, 0, 1) + + fbank = (fbank - (-4.2677393)) / (4.5689974 * 2) + return fbank + + +class ConversationGenerator: + def __init__(self, tokenizer): + self.tokenizer = tokenizer + self.header = f"{conversation_lib.default_conversation.system}\n\n" + self._probe_tokenizer_style() + + def _probe_tokenizer_style(self): + """ + Given a sentence, e.g. 
"My darling", some tokenizers will make the space a seperate token, + while some others will merge the space into the next word, forming a token representing " darling". + Knowing which style the tokenizer takes is necessary for correct ground-truth label masking. + + """ + probe = "Probe am I" + sentence1 = self.tokenizer.encode(conversation_lib.default_conversation.roles[1] + ": " + probe, + bos=False, eos=False) + sentence2 = self.tokenizer.encode(probe, + bos=False, eos=False) + if sentence1[-len(sentence2):] == sentence2: + self.space_before_to_predict = False + else: + sentence3 = self.tokenizer.encode(" " + probe, + bos=False, eos=False) + assert sentence1[-len(sentence3):] == sentence3 + self.space_before_to_predict = True + + def add_speaker_and_signal(self, source, get_conversation=True): + """Add speaker and start/end signal on each round.""" + BEGIN_SIGNAL = "### " + END_SIGNAL = "\n" + conversation = self.header + + to_predict_list = [] + + for sentence in source: + from_str = sentence["from"] + if from_str.lower() in ["human"]: + from_str = conversation_lib.default_conversation.roles[0] + elif from_str.lower() in ["gpt", "assistant"]: + from_str = conversation_lib.default_conversation.roles[1] + else: + raise ValueError(f"unknown dialog role: {from_str.lower()}") + + value = sentence["value"] + if DEFAULT_IMAGE_TOKEN in value: + value = value.replace(DEFAULT_IMAGE_TOKEN, '').strip() + + sentence_value = BEGIN_SIGNAL + from_str + ": " + value + END_SIGNAL + + if from_str == conversation_lib.default_conversation.roles[1]: + to_predict_value = value + END_SIGNAL + "###" + if self.space_before_to_predict: + to_predict_value = " " + to_predict_value + to_predict_list.append(to_predict_value) + + if get_conversation: + conversation = conversation + sentence_value + + conversation = conversation + BEGIN_SIGNAL + return conversation, to_predict_list + + +DATASETS = dict( + image=[ + dict(path="datasets/InstructionTuning/image/llava_v1_5_mix665k_image.json", type='image'), + dict(path='datasets/InstructionTuning/image/cococap_train.json', type='image'), + dict(path="datasets/InstructionTuning/image/llava_v1_5_mix665k_text.json", type='text'), + ], + audio=[ + dict(path="datasets/InstructionTuning/audio/audiocap_train.json", type='audio'), + dict(path="datasets/InstructionTuning/audio/audiocap_val.json", type='audio'), + dict(path="datasets/InstructionTuning/audio/audio_conversation.json", type='audio'), + ], + video=[ + dict(path="datasets/InstructionTuning/video/msrvtt_cap_trainval.json", type='video'), + dict(path="datasets/InstructionTuning/video/msrvtt_cap_test.json", type='video'), + dict(path="datasets/InstructionTuning/video/msrvtt_vqa_train.json", type='video'), + dict(path="datasets/InstructionTuning/video/msrvtt_vqa_val.json", type='video'), + dict(path="datasets/InstructionTuning/video/msrvtt_vqa_test.json", type='video'), + dict(path="datasets/InstructionTuning/video/video_complex_reasoning_10k.json", type='video'), + dict(path="datasets/InstructionTuning/video/video_conversation_10k.json", type='video'), + dict(path="datasets/InstructionTuning/video/video_detail_10k.json", type='video'), + ], + point=[ + dict(path="datasets/InstructionTuning/point/pointllm_70k_formated.json", type='point'), + ], + rgbd=[ + dict(path="datasets/InstructionTuning/depth_normal/llava_instruct_50k_depth.json", type='rgbd'), + ], + rgbn=[ + dict(path="datasets/InstructionTuning/depth_normal/llava_instruct_50k_normal.json", type='rgbn'), + ], + imu=[ + 
dict(path="datasets/InstructionTuning/imu/imu_fixed_50k.json", type='imu'), + ], + fmri=[ + dict(path="datasets/InstructionTuning/fmri/fmri_fixed.json", type='fmri'), + ], +) +IMU_PATH = "/mnt/petrelfs/share_data/hanjiaming/ego4d/v2/processed_imu/" + + +class FinetuneDialogDataset(Dataset): + def __init__(self, dataset=['image'], transform=T_random_resized_crop, max_words=2048, image_words=30, tokenizer_path=None): + if isinstance(dataset, str): + dataset = [dataset] + + self.dataset = dataset + + group_ann = {} + for d in dataset: + for meta in DATASETS[d]: + meta_path, meta_type = meta['path'], meta['type'] + meta_ext = os.path.splitext(meta_path)[-1] + if meta_ext == ".json": + with open(meta_path) as f: + meta_l = json.load(f) + # add data_type + # this is a temp solution + new_meta_l = [] + for l in meta_l: + l['data_type'] = meta_type + new_meta_l.append(l) + meta_l = new_meta_l + elif meta_ext == ".jsonl": + meta_l = [] + with open(meta_path) as f: + for i, line in enumerate(f): + try: + meta_l.append(json.loads(line)) + except json.decoder.JSONDecodeError as e: + print( + f"Error decoding the following jsonl line ({i}):\n{line.rstrip()}", force=True) + raise e + else: + raise NotImplementedError( + f"Unknown meta file extension: \"{meta_ext}\". " + f"Currently, .json, .jsonl are supported. " + "If you are using a supported format, please set the file extension so that the proper parsing " + "routine can be called." + ) + if meta_type not in group_ann: + group_ann[meta_type] = [] + print(f"{meta_path}, type {meta_type}: len {len(meta_l)}") + group_ann[meta_type] += meta_l + + # sort group_ann for higher efficiency (items in one global batch with similar length) + for meta_type, meta_l in group_ann.items(): + meta_l.sort(key=lambda data_item: sum( + [len(_['value']) for _ in data_item['conversations']])) + + self.group_ann = group_ann + self.ann = sum(list(self.group_ann.values()), start=[]) + + self.group_indices = {} + start_pos = 0 + for meta_type, meta_l in self.group_ann.items(): + self.group_indices[meta_type] = list( + range(start_pos, start_pos + len(meta_l))) + start_pos = start_pos + len(meta_l) + + print(f"total length: {len(self)}") + self.transform = transform + print(f"transform:\n{self.transform}") + self.max_words = max_words + self.image_words = image_words + self.tokenizer = Tokenizer(model_path=tokenizer_path) + self.conversation_generator = ConversationGenerator(self.tokenizer) + + self.load_funcs = dict( + image=self.load_image, + audio=self.load_audio, + video=self.load_video, + point=self.load_point, + rgbd=self.load_rgbx, + rgbn=self.load_rgbx, + imu=self.load_imu, + fmri=self.load_fmri + ) + + def __len__(self): + return len(self.ann) + + def load_image(self, data): + filename = data['image'] + image = Image.open(filename).convert('RGB') + image = self.transform(image) + return image + + def load_audio(self, data): + audio_path = data['image'] + fbank = make_audio_features(audio_path, mel_bins=128) + fbank = fbank.transpose(0, 1)[None] # [1, 128, 1024] + return fbank + + def load_video(self, data): + video_path = data['image'] + video_feats = video_utils.load_and_transform_video_data( + video_path, video_path, clip_duration=1, clips_per_video=5) + return video_feats[:, :, 0] + + def load_point(self, data): + point_path = data['image'] + point_feat = torch.load(point_path, map_location='cpu') + point_feat = point_feat.transpose(0, 1) + return point_feat + + def load_rgbx(self, data): + image_path = data['image'] + x_image_path = data['depth_image'] if 
'depth_image' in data else data['normal_image'] + image = Image.open(image_path).convert('RGB') + x_image = Image.open(x_image_path).convert('RGB') + x_image = x_image.resize(image.size[-2:]) + + image, x_image = transform_pairimg_train([image, x_image]) + # [2, 3, H, W] + image = torch.stack([image, x_image], dim=0) + return image + + def load_fmri(self, data): + fmri_path = data['image'] + data = np.load(fmri_path) + data = data.mean(axis=0) + data = torch.tensor(data[None]) + return data + + def load_imu(self, data_dict): + uid = data_dict["video_uid"] + w_s = data_dict["window_start"] + w_e = data_dict["window_end"] + + imu_data = get_imu_frames( + IMU_PATH, uid, + video_start_sec=w_s, + video_end_sec=w_e, + ) + if imu_data is None: + raise ValueError + return imu_data['signal'] + + def __getitem__(self, index, expect_type=None): + if expect_type is None: + data_item = self.ann[index] + else: + # in case we want to get data from a specific data_type + data_item = self.group_ann[expect_type][index] + + data_type = data_item['data_type'] + if data_type != 'text': + if data_type in self.load_funcs: + try: + image = self.load_funcs[data_type](data_item) + if image is None: + raise ValueError('Data is None') + except Exception: + print('Error', data_item) + rand_idx = random.randint( + 0, len(self.group_ann[data_type]) - 1) + return self.__getitem__(rand_idx, expect_type=data_type) + else: + raise ValueError(f'Does not support {data_type}') + else: + image = None + # warnings.warn("pure black image for examples without image") + # image = torch.zeros(3, 224, 224) + + source = data_item["conversations"] + conversation, to_predict_values = self.conversation_generator.add_speaker_and_signal( + source) + if len(to_predict_values) == 0: + warnings.warn( + f"see dialog data with nothing to predict, data: {data_item}") + return self[index-1] + + tokenzed_conversation = self.tokenizer.encode( + conversation, bos=True, eos=True) + labels = [IGNORE_INDEX for _ in tokenzed_conversation] + + check_pos = 0 + for value in to_predict_values: + tokenized_value = self.tokenizer.encode( + value, bos=False, eos=False) + value_pos = find_sublist( + tokenzed_conversation[check_pos:], tokenized_value) + check_pos + if value_pos == -1: + print( + "a sentence mismatches the corresponding piece in the conversation") + return self[index-1] + labels[value_pos:value_pos+len(tokenized_value)] = tokenized_value + assert labels[value_pos:value_pos+len( + tokenized_value)] == tokenzed_conversation[value_pos:value_pos+len(tokenized_value)] + check_pos = value_pos+len(tokenized_value) + + input2 = torch.tensor(tokenzed_conversation, dtype=torch.int64) + labels = torch.tensor(labels, dtype=torch.int64) + + if image is not None: + max_words = self.max_words - self.image_words + else: + max_words = self.max_words + padding = max_words - input2.shape[0] + if padding > 0: + input2 = torch.cat( + (input2, torch.zeros(padding, dtype=torch.int64) - 1)) + labels = torch.cat( + (labels, torch.zeros(padding, dtype=torch.int64) - 1)) + elif padding < 0: + input2 = input2[:max_words] + labels = labels[:max_words] + + input2_mask = input2.ge(0) + label_mask = labels.ge(0) + input2[~input2_mask] = 0 + labels[~label_mask] = 0 + input2_mask = input2_mask.float() + label_mask = label_mask.float() + if image is None: + return input2, labels, data_item['data_type'] + else: + return input2, labels, image, data_item['data_type'] + + def groups(self): + return list(self.group_indices.values()) + + +def find_sublist(a: list, b: list): + len_a, len_b = len(a), len(b)
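# find_sublist: slide a len(b)-sized window over a and return the first index
# where the window equals b, or -1 if b never occurs. __getitem__ above relies
# on it to locate each tokenized answer span inside the tokenized conversation
# so that only those positions receive non-ignored labels.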
+ for i in range(len_a - len_b + 1): + if a[i:i+len_b] == b: + return i + return -1 diff --git a/data/imu_utils.py b/data/imu_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..010563d67e603bd7ca5589672058380a79ee93d9 --- /dev/null +++ b/data/imu_utils.py @@ -0,0 +1,257 @@ +import string +import numpy as np +import matplotlib.animation as animation +from matplotlib import pyplot as plt +import json +from collections import defaultdict +from bisect import bisect_left +import os +import torch +import torchaudio +torchaudio.set_audio_backend("sox_io") + + +def load_json(json_path: str): + """ + Load a json file + """ + with open(json_path, "r", encoding="utf-8") as f_name: + data = json.load(f_name) + return data + + +def check_window_signal(info_t, w_s, w_e): + length = w_e - w_s + frame_offset = int(w_s * info_t.sample_rate) + num_frames = int(length * info_t.sample_rate) + if frame_offset + num_frames > int(info_t.num_frames): + return False + else: + return True + + +def index_narrations(ann_path): + narration_raw = load_json(ann_path) + + narration_dict = defaultdict(list) + summary_dict = defaultdict(list) + avg_len = [] + for v_id, narr in narration_raw.items(): + narr_list = [] + summ_list = [] + if "narration_pass_1" in narr: + narr_list += narr["narration_pass_1"]["narrations"] + summ_list += narr["narration_pass_1"]["summaries"] + if "narration_pass_2" in narr: + narr_list += narr["narration_pass_2"]["narrations"] + summ_list += narr["narration_pass_2"]["summaries"] + + if len(narr_list) > 0: + narration_dict[v_id] = [ + ( + float(n_t["timestamp_sec"]), + n_t["narration_text"], + n_t["annotation_uid"], + n_t["timestamp_frame"], + ) + for n_t in narr_list + ] + avg_len.append(len(narration_dict[v_id])) + else: + narration_dict[v_id] = [] + if len(summ_list) > 0: + summary_dict[v_id] = [ + ( + float(s_t["start_sec"]), + float(s_t["end_sec"]), + s_t["summary_text"], + ) + for s_t in summ_list + ] + else: + summary_dict[v_id] = [] + # print(f"Number of Videos with narration {len(narration_dict)}") + # print(f"Avg. 
narration length {np.mean(avg_len)}") + # print(f"Number of Videos with summaries {len(summary_dict)}") + return narration_dict, summary_dict + + +def get_signal_info(signal_fn: str): + return torchaudio.info(signal_fn) + + +def get_signal_frames(signal_fn: str, video_start_sec: float, video_end_sec: float): + """ + Given a signal track return the frames between video_start_sec and video_end_sec + """ + info_t = get_signal_info(signal_fn) + + length = video_end_sec - video_start_sec + aframes, _ = torchaudio.load( + signal_fn, + normalize=True, + frame_offset=int(video_start_sec * info_t.sample_rate), + num_frames=int(length * info_t.sample_rate), + ) + return {"signal": aframes, "meta": info_t} + + +def tosec(value): + return value / 1000 + + +def toms(value): + return value * 1000 + + +def delta(first_num: float, second_num: float): + """Compute the absolute value of the difference of two numbers""" + return abs(first_num - second_num) + + +def padIMU(signal, duration_sec): + """ + Pad the signal if necessary + """ + expected_elements = round(duration_sec) * 200 + + if signal.shape[0] > expected_elements: + signal = signal[:expected_elements, :] + elif signal.shape[0] < expected_elements: + padding = expected_elements - signal.shape[0] + padded_zeros = np.zeros((padding, 6)) + signal = np.concatenate([signal, padded_zeros], 0) + # signal = signal[:expected_elements, :] + return signal + + +def resample( + signals: np.ndarray, + timestamps: np.ndarray, + original_sample_rate: int, + resample_rate: int, +): + """ + Resamples data to new sample rate + """ + signals = torch.as_tensor(signals) + timestamps = torch.from_numpy(timestamps).unsqueeze(-1) + signals = torchaudio.functional.resample( + waveform=signals.data.T, + orig_freq=original_sample_rate, + new_freq=resample_rate, + ).T.numpy() + + nsamples = len(signals) + + period = 1 / resample_rate + + # timestamps are expected to be shape (N, 1) + initital_seconds = timestamps[0] / 1e3 + + ntimes = (torch.arange(nsamples) * period).view(-1, 1) + initital_seconds + + timestamps = (ntimes * 1e3).squeeze().numpy() + return signals, timestamps + + +def resampleIMU(signal, timestamps): + sampling_rate = int(1000 * (1 / (np.mean(np.diff(timestamps))))) + # resample all to 200hz + if sampling_rate != 200: + signal, timestamps = resample(signal, timestamps, sampling_rate, 200) + return signal, timestamps + + +def get_imu_frames( + imu_path, + uid: str, + video_start_sec: float, + video_end_sec: float, +): + """ + Given a IMU signal return the frames between video_start_sec and video_end_sec + """ + signal = np.load(os.path.join(imu_path, f"{uid}.npy")) + signal = signal.transpose() + timestamps = np.load(os.path.join(imu_path, f"{uid}_timestamps.npy")) + + if toms(video_start_sec) > timestamps[-1] or toms(video_end_sec) > timestamps[-1]: + return None + + start_id = bisect_left(timestamps, toms(video_start_sec)) + end_id = bisect_left(timestamps, toms(video_end_sec)) + + # make sure the retrieved window interval are correct by a max of 1 sec margin + if ( + delta(video_start_sec, tosec(timestamps[start_id])) > 4 + or delta(video_end_sec, tosec(timestamps[end_id])) > 4 + ): + return None + + # get the window + if start_id == end_id: + start_id -= 1 + end_id += 1 + signal, timestamps = signal[start_id:end_id], timestamps[start_id:end_id] + + if len(signal) < 10 or len(timestamps) < 10: + return None + # resample the signal at 200hz if necessary + signal, timestamps = resampleIMU(signal, timestamps) + + # pad the signal if necessary + signal = 
padIMU(signal, video_end_sec - video_start_sec) + + sample_dict = { + "timestamp": timestamps, + "signal": torch.tensor(signal.T), + "sampling_rate": 200, + } + + return sample_dict + + +def display_animation(frames, title, save_path_gif): + fig, ax = plt.subplots() + frames = [[ax.imshow(frames[i])] for i in range(len(frames))] + plt.title(title) + ani = animation.ArtistAnimation(fig, frames) + ani.save(save_path_gif, writer="imagemagick") + plt.close() + + +def display_animation_imu(frames, imu, title, save_path_gif): + fig, (ax1, ax2, ax3) = plt.subplots(3, 1) + ax1.set_title(title) + ax2.set_title("Acc.") + ax3.set_title("Gyro.") + frames = [[ax1.imshow(frames[i])] for i in range(len(frames))] + ani = animation.ArtistAnimation(fig, frames) + + ax2.plot(imu[0].cpu().numpy(), color="red") + ax2.plot(imu[1].cpu().numpy(), color="blue") + ax2.plot(imu[2].cpu().numpy(), color="green") + ax3.plot(imu[3].cpu().numpy(), color="red") + ax3.plot(imu[4].cpu().numpy(), color="blue") + ax3.plot(imu[5].cpu().numpy(), color="green") + plt.tight_layout() + ani.save(save_path_gif, writer="imagemagick") + plt.close() + + +def filter_narration(narration_text: str) -> bool: + if "#c" in narration_text.lower(): + return True + return False + + +def clean_narration_text(narration_text: str) -> str: + return ( + narration_text.replace("#C C ", "") + .replace("#C", "") + .replace("#unsure", "something") + .strip() + .strip(string.punctuation) + .lower()[:128] + ) diff --git a/data/video_utils.py b/data/video_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..43ac03067e50c8570d422a057b9d9efb18e8775b --- /dev/null +++ b/data/video_utils.py @@ -0,0 +1,204 @@ +import math +import torch +import torch.nn as nn +from pytorchvideo import transforms as pv_transforms +from pytorchvideo.data.clip_sampling import ConstantClipsPerVideoSampler +from pytorchvideo.data.encoded_video import EncodedVideo +from pytorchvideo.data.encoded_video_decord import EncodedVideoDecord +from torchvision import transforms +from torchvision.transforms._transforms_video import NormalizeVideo + + +def get_clip_timepoints(clip_sampler, duration): + # Read out all clips in this video + all_clips_timepoints = [] + is_last_clip = False + end = 0.0 + while not is_last_clip: + start, end, _, _, is_last_clip = clip_sampler(end, duration, annotation=None) + all_clips_timepoints.append((start, end)) + return all_clips_timepoints + + + +def crop_boxes(boxes, x_offset, y_offset): + """ + Perform crop on the bounding boxes given the offsets. + Args: + boxes (ndarray or None): bounding boxes to perform crop. The dimension + is `num boxes` x 4. + x_offset (int): cropping offset in the x axis. + y_offset (int): cropping offset in the y axis. + Returns: + cropped_boxes (ndarray or None): the cropped boxes with dimension of + `num boxes` x 4. + """ + cropped_boxes = boxes.copy() + cropped_boxes[:, [0, 2]] = boxes[:, [0, 2]] - x_offset + cropped_boxes[:, [1, 3]] = boxes[:, [1, 3]] - y_offset + + return cropped_boxes + + +def uniform_crop(images, size, spatial_idx, boxes=None, scale_size=None): + """ + Perform uniform spatial sampling on the images and corresponding boxes. + Args: + images (tensor): images to perform uniform crop. The dimension is + `num frames` x `channel` x `height` x `width`. + size (int): size of height and weight to crop the images. + spatial_idx (int): 0, 1, or 2 for left, center, and right crop if width + is larger than height. Or 0, 1, or 2 for top, center, and bottom + crop if height is larger than width. 
+ boxes (ndarray or None): optional. Corresponding boxes to images. + Dimension is `num boxes` x 4. + scale_size (int): optinal. If not None, resize the images to scale_size before + performing any crop. + Returns: + cropped (tensor): images with dimension of + `num frames` x `channel` x `size` x `size`. + cropped_boxes (ndarray or None): the cropped boxes with dimension of + `num boxes` x 4. + """ + assert spatial_idx in [0, 1, 2] + ndim = len(images.shape) + if ndim == 3: + images = images.unsqueeze(0) + height = images.shape[2] + width = images.shape[3] + + if scale_size is not None: + if width <= height: + width, height = scale_size, int(height / width * scale_size) + else: + width, height = int(width / height * scale_size), scale_size + images = torch.nn.functional.interpolate( + images, + size=(height, width), + mode="bilinear", + align_corners=False, + ) + + y_offset = int(math.ceil((height - size) / 2)) + x_offset = int(math.ceil((width - size) / 2)) + + if height > width: + if spatial_idx == 0: + y_offset = 0 + elif spatial_idx == 2: + y_offset = height - size + else: + if spatial_idx == 0: + x_offset = 0 + elif spatial_idx == 2: + x_offset = width - size + cropped = images[:, :, y_offset : y_offset + size, x_offset : x_offset + size] + cropped_boxes = crop_boxes(boxes, x_offset, y_offset) if boxes is not None else None + if ndim == 3: + cropped = cropped.squeeze(0) + return cropped, cropped_boxes + + +class SpatialCrop(nn.Module): + """ + Convert the video into 3 smaller clips spatially. Must be used after the + temporal crops to get spatial crops, and should be used with + -2 in the spatial crop at the slowfast augmentation stage (so full + frames are passed in here). Will return a larger list with the + 3x spatial crops as well. + """ + + def __init__(self, crop_size: int = 224, num_crops: int = 3): + super().__init__() + self.crop_size = crop_size + if num_crops == 3: + self.crops_to_ext = [0, 1, 2] + self.flipped_crops_to_ext = [] + elif num_crops == 1: + self.crops_to_ext = [1] + self.flipped_crops_to_ext = [] + else: + raise NotImplementedError("Nothing else supported yet") + + def forward(self, videos): + """ + Args: + videos: A list of C, T, H, W videos. + Returns: + videos: A list with 3x the number of elements. Each video converted + to C, T, H', W' by spatial cropping. 
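+ With num_crops=3 these are the left/center/right (or, for portrait inputs,
+ top/center/bottom) crops produced by uniform_crop with spatial_idx 0, 1 and 2.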
+ """ + assert isinstance(videos, list), "Must be a list of videos after temporal crops" + assert all([video.ndim == 4 for video in videos]), "Must be (C,T,H,W)" + res = [] + for video in videos: + for spatial_idx in self.crops_to_ext: + res.append(uniform_crop(video, self.crop_size, spatial_idx)[0]) + if not self.flipped_crops_to_ext: + continue + flipped_video = transforms.functional.hflip(video) + for spatial_idx in self.flipped_crops_to_ext: + res.append(uniform_crop(flipped_video, self.crop_size, spatial_idx)[0]) + return res + + +def load_and_transform_video_data( + video_file, + video_path, + clip_duration=2, + clips_per_video=5, + sample_rate=16000, + with_audio=False +): + video_transform = transforms.Compose( + [ + pv_transforms.ShortSideScale(224), + NormalizeVideo( + mean=(0.48145466, 0.4578275, 0.40821073), + std=(0.26862954, 0.26130258, 0.27577711), + ), + ] + ) + + clip_sampler = ConstantClipsPerVideoSampler( + clip_duration=clip_duration, clips_per_video=clips_per_video + ) + frame_sampler = pv_transforms.UniformTemporalSubsample(num_samples=clip_duration) + + if isinstance(video_file, str): + video = EncodedVideo.from_path( + video_file, + decoder="decord", + decode_audio=with_audio, + # **{"sample_rate": sample_rate}, + ) + else: + video = EncodedVideoDecord(video_file, video_name=video_path, decode_video=True, decode_audio=with_audio, sample_rate=sample_rate) + + all_clips_timepoints = get_clip_timepoints(clip_sampler, video.duration) + + all_video = [] + for clip_timepoints in all_clips_timepoints: + # Read the clip, get frames + clip = video.get_clip(clip_timepoints[0], clip_timepoints[1]) + if clip is None: + raise ValueError("No clip found") + video_clip = frame_sampler(clip["video"]) + video_clip = video_clip / 255.0 # since this is float, need 0-1 + + all_video.append(video_clip) + + all_video = [video_transform(clip) for clip in all_video] + all_video = SpatialCrop(224, num_crops=3)(all_video) + + all_video = torch.stack(all_video, dim=0) + + if not with_audio: + return all_video + else: + return all_video, clip['audio'] + +if __name__ == '__main__': + video_path = "datasets/InstructionTuning/video/music_aqa/MUSIC-AVQA-videos-Real/00000002.mp4" + video, audio = load_and_transform_video_data(video_path, video_path, clip_duration=1, clips_per_video=5, with_audio=True) + import pdb;pdb.set_trace() \ No newline at end of file diff --git a/demos/multi_turn_mm.py b/demos/multi_turn_mm.py new file mode 100644 index 0000000000000000000000000000000000000000..6f354e6c68d0a09df50c87a1a53f110a4fe7321a --- /dev/null +++ b/demos/multi_turn_mm.py @@ -0,0 +1,300 @@ +import sys +import os +sys.path.append(os.path.abspath(__file__).rsplit('/', 2)[0]) + +import argparse +import multiprocessing as mp +import numpy as np +from typing import List, Optional + +import torch +import torch.distributed as dist + +from fairscale.nn.model_parallel import initialize as fs_init + +import gradio as gr +from util.misc import setup_for_distributed +from util.misc import default_tensor_type +from model.meta import MetaModel +from data.conversation_lib import conv_templates, SeparatorStyle +from PIL import Image +import torchvision.transforms as transforms +from data.fintune_dataset import make_audio_features +from data import video_utils + + +T_random_resized_crop = transforms.Compose([ + transforms.RandomResizedCrop(size=(224, 224), scale=(0.9, 1.0), ratio=(0.75, 1.3333), interpolation=3, + antialias=None), # 3 is bicubic + transforms.ToTensor(), + transforms.Normalize(mean=[0.48145466, 0.4578275, 
0.40821073], std=[0.26862954, 0.26130258, 0.27577711])]) + + +def load_audio(audio_path): + fbank = make_audio_features(audio_path, mel_bins=128) + fbank = fbank.transpose(0, 1)[None] #[1, 128, 1024] + return fbank + +def load_video(video_path): + video_feats = video_utils.load_and_transform_video_data(video_path, video_path, clip_duration=1, clips_per_video=5) + return video_feats[:, :, 0] + + +def model_worker( + rank: int, args: argparse.Namespace, barrier: mp.Barrier, + request_queue: mp.Queue, response_queue: Optional[mp.Queue] = None, +) -> None: + """ + The worker function that manipulates the GPU to run the inference. + Exact n_gpu workers are started, with each one operating on a separate GPU. + + Args: + rank (int): Distributed rank of the worker. + args (argparse.Namespace): All command line arguments. + barrier (multiprocessing.Barrier): A barrier used to delay the start + of Web UI to be after the start of the model. + """ + + world_size = len(args.gpu_ids) + gpu_id = args.gpu_ids[rank] + dist.init_process_group( + backend="nccl", rank=rank, world_size=world_size, + init_method=f"tcp://{args.master_addr}:{args.master_port}", + ) + print(f"| distributed init on worker {rank}/{world_size}. " + f"using gpu: {gpu_id}") + fs_init.initialize_model_parallel(world_size) + torch.cuda.set_device(gpu_id) + + torch.manual_seed(1) + np.random.seed(1) + + # set the print behavior. + setup_for_distributed(rank == 0) + + target_dtype = { + "bf16": torch.bfloat16, + "fp16": torch.float16 + }[args.dtype] + with default_tensor_type(dtype=target_dtype, device="cuda"): + model = MetaModel(args.llama_type, args.llama_config, tokenizer_path=args.tokenizer_path) + print("Loading pretrained weights ...") + checkpoint = torch.load(args.pretrained_path, map_location='cpu') + msg = model.load_state_dict(checkpoint, strict=False) + print("load result:\n", msg) + model.cuda() + model.eval() + print(f"Model = {str(model)}") + + barrier.wait() + + while True: + img_path, audio_path, video_path, chatbot, max_gen_len, temperature, top_p, modality = request_queue.get() + if 'image' in modality and img_path is not None: + image = Image.open(img_path).convert('RGB') + inputs = T_random_resized_crop(image) + elif 'video' in modality and video_path is not None: + inputs = load_video(video_path) + elif 'audio' in modality and audio_path is not None: + inputs = load_audio(audio_path) + else: + inputs = None + + if inputs is not None: + inputs = inputs[None].cuda().to(target_dtype) + + conv = conv_templates["v1"].copy() + for user, bot in chatbot: + conv.append_message(conv.roles[0], user) + conv.append_message(conv.roles[1], bot) + + with torch.cuda.amp.autocast(dtype=target_dtype): + print(conv.get_prompt()) + for stream_response in model.stream_generate( + conv.get_prompt(), inputs, + max_gen_len=max_gen_len, temperature=temperature, top_p=top_p, + modal = modality + ): + conv_sep = ( + conv.sep + if conv.sep_style == SeparatorStyle.SINGLE + else conv.sep2 + ) + end_pos = stream_response["text"].find(conv_sep) + if end_pos != -1: + stream_response["text"] = ( + stream_response['text'][:end_pos].rstrip() + "\n" + ) + stream_response["end_of_content"] = True + + # keep a few characters if not end_of_content to avoid sending + # part of conv_sep before all of it is generated. 
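+ # concretely: chunks shorter than conv_sep are skipped outright, and otherwise
+ # the last len(conv_sep) characters are withheld from the streamed text until
+ # either the full separator or the end of generation is observed.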
+ if not stream_response["end_of_content"]: + if len(stream_response["text"]) < len(conv_sep): + continue + stream_response["text"] = ( + stream_response["text"][:-len(conv_sep)] + ) + + if response_queue is not None: + response_queue.put(stream_response) + + if stream_response["end_of_content"]: + break + + +def gradio_worker( + request_queues: List[mp.Queue], response_queue: mp.Queue, + args: argparse.Namespace, barrier: mp.Barrier, +) -> None: + """ + The gradio worker is responsible for displaying the WebUI and relay the + requests to model workers. It should be launched only once. + + Args: + request_queues (List[mp.Queue]): A list of request queues (one for + each model worker). + args (argparse.Namespace): All command line arguments. + barrier (multiprocessing.Barrier): A barrier used to delay the start + of Web UI to be after the start of the model. + """ + + def show_user_input(msg, chatbot): + return "", chatbot + [[msg, None]] + + def stream_model_output(img_path, audio_path, video_path, chatbot, max_gen_len, gen_t, top_p, modality): + for queue in request_queues: + queue.put((img_path, audio_path, video_path, chatbot, max_gen_len, gen_t, top_p, modality)) + while True: + content_piece = response_queue.get() + chatbot[-1][1] = content_piece["text"] + yield chatbot + if content_piece["end_of_content"]: + break + + def undo(chatbot): + if len(chatbot) > 0: + chatbot = chatbot[:-1] + return chatbot + + def clear(): + chatbot = [] + msg = "" + return chatbot, msg + + CSS =""" + .contain { display: flex; flex-direction: column; } + #component-0 { height: 100%; } + #chatbot { flex-grow: 1; overflow: auto;} + """ + with gr.Blocks(css=CSS) as demo: + gr.Markdown("## OneLLM: One Framework to Align All Modalities with Language") + with gr.Row(equal_height=True): + with gr.Column(scale=1): + img_path = gr.Image(label='Image Input', type='filepath') + video_path = gr.Video(label='Video Input') + audio_path = gr.Audio(label='Audio Input', type='filepath', sources=['upload']) + modality = gr.Radio(choices=['image', 'audio', 'video'], value='image', interactive=True, label='Input Modalities') + + with gr.Column(scale=2): + chatbot = gr.Chatbot(elem_id="chatbot") + msg = gr.Textbox() + + with gr.Row(): + submit_button = gr.Button("Submit", variant="primary") + undo_button = gr.Button("Undo") + clear_button = gr.ClearButton([chatbot, msg, img_path, audio_path, video_path, modality]) + with gr.Row(): + max_gen_len = gr.Slider( + minimum=1, maximum=args.model_max_seq_len // 2, + value=args.model_max_seq_len // 2, interactive=True, + label="Single-turn max response length", + ) + gen_t = gr.Slider( + minimum=0, maximum=1, value=0.1, interactive=True, + label="Temperature", + ) + top_p = gr.Slider( + minimum=0, maximum=1, value=0.75, interactive=True, + label="Top-p", + ) + msg.submit( + show_user_input, [msg, chatbot], [msg, chatbot], + ).then( + stream_model_output, [img_path, audio_path, video_path, chatbot, max_gen_len, gen_t, top_p, modality], chatbot, + ) + submit_button.click( + show_user_input, [msg, chatbot], [msg, chatbot], + ).then( + stream_model_output, [img_path, audio_path, video_path, chatbot, max_gen_len, gen_t, top_p, modality], chatbot, + ) + undo_button.click(undo, chatbot, chatbot) + # img_path.change(clear, [], [chatbot, msg]) + barrier.wait() + demo.queue(api_open=True).launch(share=True, max_threads=1) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser("Chat Demo") + group = parser.add_mutually_exclusive_group() + group.add_argument( + "--gpu_ids", type=int, 
nargs="+", + help="A list of space-separated gpu ids to run the model on. " + "The model will span across GPUs in tensor-parallel mode." + ) + parser.add_argument( + "--tokenizer_path", type=str, + help="Path to the tokenizer.model file provided along with the LLaMA " + "model." + ) + parser.add_argument( + "--llama_type", default="onellm", type=str, metavar="MODEL", + help="LLaMA model type." + ) + parser.add_argument( + "--llama_config", type=str, required=True, + help="Path to the llama model config json." + ) + parser.add_argument( + "--model_max_seq_len", type=int, default=2048, + help="Max sequence length accepted by the pretrained model." + ) + parser.add_argument( + "--pretrained_path", type=str, required=True, + help="Path to the llama model checkpoints. A list of checkpoints is " + "supported and will be merged from left to right.") + parser.add_argument( + "--master_port", type=int, default=23862, + help="A port used by the PyTorch distributed module to initialize." + ) + parser.add_argument( + "--master_addr", type=str, default="127.0.0.1", + help="An address used by the PyTorch distributed module to initialize." + ) + parser.add_argument( + "--dtype", type=str, choices=["fp16", "bf16"], default="fp16", + help="The dtype used for model weights and inference." + ) + args = parser.parse_args() + + # using the default "fork" method messes up some imported libs (e.g., + # pandas) + mp.set_start_method("spawn") + + # setup the queues and start the model workers + request_queues = [] + response_queue = mp.Queue() + worker_processes = [] + barrier = mp.Barrier(len(args.gpu_ids) + 1) + for rank, gpu_id in enumerate(args.gpu_ids): + request_queue = mp.Queue() + rank_response_queue = response_queue if rank == 0 else None + process = mp.Process( + target=model_worker, + args=(rank, args, barrier, request_queue, rank_response_queue), + ) + process.start() + worker_processes.append(process) + request_queues.append(request_queue) + + gradio_worker(request_queues, response_queue, args, barrier) diff --git a/lib/__pycache__/point_utils.cpython-310.pyc b/lib/__pycache__/point_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b52bf4169d4d84233f3178c745896d1fa395824f Binary files /dev/null and b/lib/__pycache__/point_utils.cpython-310.pyc differ diff --git a/lib/point_utils.py b/lib/point_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..834733a64b540a141bfce09f6d0fae3154f89997 --- /dev/null +++ b/lib/point_utils.py @@ -0,0 +1,191 @@ +import torch +import torch.nn as nn +from torch.autograd import Function +import pointnet2_cuda + +class KNN(nn.Module): + def __init__(self, neighbors, transpose_mode=True): + super(KNN, self).__init__() + self.neighbors = neighbors + + @torch.no_grad() + def forward(self, support, query): + """ + Args: + support ([tensor]): [B, N, C] + query ([tensor]): [B, M, C] + Returns: + [int]: neighbor idx. 
[B, M, K] + """ + dist = torch.cdist(support, query) + k_dist = dist.topk(k=self.neighbors, dim=1, largest=False) + return k_dist.values, k_dist.indices.transpose(1, 2).contiguous().int() + + +class GroupingOperation(Function): + + @staticmethod + @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32) + def forward(ctx, features: torch.Tensor, idx: torch.Tensor) -> torch.Tensor: + """ + :param ctx: + :param features: (B, C, N) tensor of features to group + :param idx: (B, npoint, nsample) tensor containing the indicies of features to group with + :return: + output: (B, C, npoint, nsample) tensor + """ + assert features.is_contiguous() + assert idx.is_contiguous() + + B, nfeatures, nsample = idx.size() + _, C, N = features.size() + output = torch.cuda.FloatTensor(B, C, nfeatures, nsample, device=features.device) + + pointnet2_cuda.group_points_wrapper(B, C, N, nfeatures, nsample, features, idx, output) + + ctx.for_backwards = (idx, N) + return output + + @staticmethod + def backward(ctx, grad_out: torch.Tensor): + """ + :param ctx: + :param grad_out: (B, C, npoint, nsample) tensor of the gradients of the output from forward + :return: + grad_features: (B, C, N) gradient of the features + """ + idx, N = ctx.for_backwards + + B, C, npoint, nsample = grad_out.size() + grad_features = torch.zeros([B, C, N], dtype=torch.float, device=grad_out.device, requires_grad=True) + grad_out_data = grad_out.data.contiguous() + pointnet2_cuda.group_points_grad_wrapper(B, C, N, npoint, nsample, grad_out_data, idx, grad_features.data) + return grad_features, None + +grouping_operation = GroupingOperation.apply + + +class KNNGroup(nn.Module): + def __init__(self, nsample: int, + relative_xyz=True, + normalize_dp=False, + return_only_idx=False, + **kwargs + ): + """[summary] + + Args: + nsample (int): maximum number of features to gather in the ball + use_xyz (bool, optional): concate xyz. Defaults to True. + ret_grouped_xyz (bool, optional): [description]. Defaults to False. + normalize_dp (bool, optional): [description]. Defaults to False. 
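+ Note: nsample is the K used by the internal KNN; when relative_xyz is True the
+ query point coordinates are subtracted from each grouped neighbor in forward().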
+ """ + super().__init__() + self.nsample = nsample + self.knn = KNN(nsample, transpose_mode=True) + self.relative_xyz = relative_xyz + self.normalize_dp = normalize_dp + self.return_only_idx = return_only_idx + + def forward(self, query_xyz: torch.Tensor, support_xyz: torch.Tensor, features: torch.Tensor = None): + """ + :param query_xyz: (B, N, 3) xyz coordinates of the features + :param support_xyz: (B, npoint, 3) centroids + :param features: (B, C, N) descriptors of the features + :return: + new_features: (B, 3 + C, npoint, nsample) + """ + _, idx = self.knn(support_xyz, query_xyz) + if self.return_only_idx: + return idx + idx = idx.int() + xyz_trans = support_xyz.transpose(1, 2).contiguous() + grouped_xyz = grouping_operation(xyz_trans, idx) # (B, 3, npoint, nsample) + if self.relative_xyz: + grouped_xyz -= query_xyz.transpose(1, 2).unsqueeze(-1) # relative position + if self.normalize_dp: + grouped_xyz /= torch.amax(torch.sqrt(torch.sum(grouped_xyz**2, dim=1)), dim=(1, 2)).view(-1, 1, 1, 1) + if features is not None: + grouped_features = grouping_operation(features, idx) + return grouped_xyz, grouped_features + else: + return grouped_xyz, None + + +class FurthestPointSampling(Function): + @staticmethod + def forward(ctx, xyz: torch.Tensor, npoint: int) -> torch.Tensor: + """ + Uses iterative furthest point sampling to select a set of npoint features that have the largest + minimum distance + :param ctx: + :param xyz: (B, N, 3) where N > npoint + :param npoint: int, number of features in the sampled set + :return: + output: (B, npoint) tensor containing the set (idx) + """ + assert xyz.is_contiguous() + + B, N, _ = xyz.size() + # output = torch.cuda.IntTensor(B, npoint, device=xyz.device) + # temp = torch.cuda.FloatTensor(B, N, device=xyz.device).fill_(1e10) + output = torch.cuda.IntTensor(B, npoint) + temp = torch.cuda.FloatTensor(B, N).fill_(1e10) + + pointnet2_cuda.furthest_point_sampling_wrapper( + B, N, npoint, xyz, temp, output) + return output + + @staticmethod + def backward(xyz, a=None): + return None, None + +furthest_point_sample = FurthestPointSampling.apply + + +class PointPatchEmbed(nn.Module): + + def __init__(self, + sample_ratio=0.0625, + sample_number=1024, + group_size=32, + in_channels=6, + channels=1024, + kernel_size=1, + stride=1, + normalize_dp=False, + relative_xyz=True, + ): + super().__init__() + self.sample_ratio = sample_ratio + self.sample_number = sample_number + self.group_size = group_size + + self.sample_fn = furthest_point_sample + self.grouper = KNNGroup(self.group_size, relative_xyz=relative_xyz, normalize_dp=normalize_dp) + + self.conv1 = nn.Conv2d(in_channels, channels, kernel_size=kernel_size, stride=stride) + + + def forward(self, x): + # coordinates + p = x[:, :, 3:].contiguous() + + B, N, _ = p.shape[:3] + # idx = self.sample_fn(p, int(N * self.sample_ratio)).long() + idx = self.sample_fn(p, self.sample_number).long() + center_p = torch.gather(p, 1, idx.unsqueeze(-1).expand(-1, -1, 3)) + # query neighbors. 
+ _, fj = self.grouper(center_p, p, x.permute(0, 2, 1).contiguous()) # [B, N, 6] -> [B, 6, N] -> [B, 6, 1024, 32] + + # [B, 6, 1024] -> [B, channels, 1024, 1] + fj = self.conv1(fj).max(dim=-1, keepdim=True)[0] + + return fj + + +if __name__ == '__main__': + model = PointPatchEmbed(channels=256).cuda() + input = torch.rand(4, 16384, 6).cuda() + ou = model(input) + import pdb;pdb.set_trace() \ No newline at end of file diff --git a/lib/pointnet2/pointnet2_modules.py b/lib/pointnet2/pointnet2_modules.py new file mode 100644 index 0000000000000000000000000000000000000000..5f125ce5075c738897e5f6a78c71123d0e3e44a2 --- /dev/null +++ b/lib/pointnet2/pointnet2_modules.py @@ -0,0 +1,160 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from . import pointnet2_utils +from . import pytorch_utils as pt_utils +from typing import List + + +class _PointnetSAModuleBase(nn.Module): + + def __init__(self): + super().__init__() + self.npoint = None + self.groupers = None + self.mlps = None + self.pool_method = 'max_pool' + + def forward(self, xyz: torch.Tensor, features: torch.Tensor = None, new_xyz=None) -> (torch.Tensor, torch.Tensor): + """ + :param xyz: (B, N, 3) tensor of the xyz coordinates of the features + :param features: (B, N, C) tensor of the descriptors of the the features + :param new_xyz: + :return: + new_xyz: (B, npoint, 3) tensor of the new features' xyz + new_features: (B, npoint, \sum_k(mlps[k][-1])) tensor of the new_features descriptors + """ + new_features_list = [] + + xyz_flipped = xyz.transpose(1, 2).contiguous() + if new_xyz is None: + new_xyz = pointnet2_utils.gather_operation( + xyz_flipped, + pointnet2_utils.furthest_point_sample(xyz, self.npoint) + ).transpose(1, 2).contiguous() if self.npoint is not None else None + + for i in range(len(self.groupers)): + new_features = self.groupers[i](xyz, new_xyz, features) # (B, C, npoint, nsample) + + new_features = self.mlps[i](new_features) # (B, mlp[-1], npoint, nsample) + if self.pool_method == 'max_pool': + new_features = F.max_pool2d( + new_features, kernel_size=[1, new_features.size(3)] + ) # (B, mlp[-1], npoint, 1) + elif self.pool_method == 'avg_pool': + new_features = F.avg_pool2d( + new_features, kernel_size=[1, new_features.size(3)] + ) # (B, mlp[-1], npoint, 1) + else: + raise NotImplementedError + + new_features = new_features.squeeze(-1) # (B, mlp[-1], npoint) + new_features_list.append(new_features) + + return new_xyz, torch.cat(new_features_list, dim=1) + + +class PointnetSAModuleMSG(_PointnetSAModuleBase): + """Pointnet set abstraction layer with multiscale grouping""" + + def __init__(self, *, npoint: int, radii: List[float], nsamples: List[int], mlps: List[List[int]], bn: bool = True, + use_xyz: bool = True, pool_method='max_pool', instance_norm=False): + """ + :param npoint: int + :param radii: list of float, list of radii to group with + :param nsamples: list of int, number of samples in each ball query + :param mlps: list of list of int, spec of the pointnet before the global pooling for each scale + :param bn: whether to use batchnorm + :param use_xyz: + :param pool_method: max_pool / avg_pool + :param instance_norm: whether to use instance_norm + """ + super().__init__() + + assert len(radii) == len(nsamples) == len(mlps) + + self.npoint = npoint + self.groupers = nn.ModuleList() + self.mlps = nn.ModuleList() + for i in range(len(radii)): + radius = radii[i] + nsample = nsamples[i] + self.groupers.append( + pointnet2_utils.QueryAndGroup(radius, nsample, use_xyz=use_xyz) + if npoint is not 
None else pointnet2_utils.GroupAll(use_xyz) + ) + mlp_spec = mlps[i] + if use_xyz: + mlp_spec[0] += 3 + + self.mlps.append(pt_utils.SharedMLP(mlp_spec, bn=bn, instance_norm=instance_norm)) + self.pool_method = pool_method + + +class PointnetSAModule(PointnetSAModuleMSG): + """Pointnet set abstraction layer""" + + def __init__(self, *, mlp: List[int], npoint: int = None, radius: float = None, nsample: int = None, + bn: bool = True, use_xyz: bool = True, pool_method='max_pool', instance_norm=False): + """ + :param mlp: list of int, spec of the pointnet before the global max_pool + :param npoint: int, number of features + :param radius: float, radius of ball + :param nsample: int, number of samples in the ball query + :param bn: whether to use batchnorm + :param use_xyz: + :param pool_method: max_pool / avg_pool + :param instance_norm: whether to use instance_norm + """ + super().__init__( + mlps=[mlp], npoint=npoint, radii=[radius], nsamples=[nsample], bn=bn, use_xyz=use_xyz, + pool_method=pool_method, instance_norm=instance_norm + ) + + +class PointnetFPModule(nn.Module): + r"""Propigates the features of one set to another""" + + def __init__(self, *, mlp: List[int], bn: bool = True): + """ + :param mlp: list of int + :param bn: whether to use batchnorm + """ + super().__init__() + self.mlp = pt_utils.SharedMLP(mlp, bn=bn) + + def forward( + self, unknown: torch.Tensor, known: torch.Tensor, unknow_feats: torch.Tensor, known_feats: torch.Tensor + ) -> torch.Tensor: + """ + :param unknown: (B, n, 3) tensor of the xyz positions of the unknown features + :param known: (B, m, 3) tensor of the xyz positions of the known features + :param unknow_feats: (B, C1, n) tensor of the features to be propigated to + :param known_feats: (B, C2, m) tensor of features to be propigated + :return: + new_features: (B, mlp[-1], n) tensor of the features of the unknown features + """ + if known is not None: + dist, idx = pointnet2_utils.three_nn(unknown, known) + dist_recip = 1.0 / (dist + 1e-8) + norm = torch.sum(dist_recip, dim=2, keepdim=True) + weight = dist_recip / norm + + interpolated_feats = pointnet2_utils.three_interpolate(known_feats, idx, weight) + else: + interpolated_feats = known_feats.expand(*known_feats.size()[0:2], unknown.size(1)) + + if unknow_feats is not None: + new_features = torch.cat([interpolated_feats, unknow_feats], dim=1) # (B, C2 + C1, n) + else: + new_features = interpolated_feats + + new_features = new_features.unsqueeze(-1) + new_features = self.mlp(new_features) + + return new_features.squeeze(-1) + + +if __name__ == "__main__": + pass diff --git a/lib/pointnet2/pointnet2_utils.py b/lib/pointnet2/pointnet2_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e814102d8feb5e443e64a736e7733818e0a24685 --- /dev/null +++ b/lib/pointnet2/pointnet2_utils.py @@ -0,0 +1,290 @@ +import torch +from torch.autograd import Variable +from torch.autograd import Function +import torch.nn as nn +from typing import Tuple + +import pointnet2_cuda as pointnet2 + + +class FurthestPointSampling(Function): + @staticmethod + def forward(ctx, xyz: torch.Tensor, npoint: int) -> torch.Tensor: + """ + Uses iterative furthest point sampling to select a set of npoint features that have the largest + minimum distance + :param ctx: + :param xyz: (B, N, 3) where N > npoint + :param npoint: int, number of features in the sampled set + :return: + output: (B, npoint) tensor containing the set + """ + assert xyz.is_contiguous() + + B, N, _ = xyz.size() + output = torch.cuda.IntTensor(B, 
npoint) + temp = torch.cuda.FloatTensor(B, N).fill_(1e10) + + pointnet2.furthest_point_sampling_wrapper(B, N, npoint, xyz, temp, output) + return output + + @staticmethod + def backward(xyz, a=None): + return None, None + + +furthest_point_sample = FurthestPointSampling.apply + + +class GatherOperation(Function): + + @staticmethod + def forward(ctx, features: torch.Tensor, idx: torch.Tensor) -> torch.Tensor: + """ + :param ctx: + :param features: (B, C, N) + :param idx: (B, npoint) index tensor of the features to gather + :return: + output: (B, C, npoint) + """ + assert features.is_contiguous() + assert idx.is_contiguous() + + B, npoint = idx.size() + _, C, N = features.size() + output = torch.cuda.FloatTensor(B, C, npoint) + + pointnet2.gather_points_wrapper(B, C, N, npoint, features, idx, output) + + ctx.for_backwards = (idx, C, N) + return output + + @staticmethod + def backward(ctx, grad_out): + idx, C, N = ctx.for_backwards + B, npoint = idx.size() + + grad_features = Variable(torch.cuda.FloatTensor(B, C, N).zero_()) + grad_out_data = grad_out.data.contiguous() + pointnet2.gather_points_grad_wrapper(B, C, N, npoint, grad_out_data, idx, grad_features.data) + return grad_features, None + + +gather_operation = GatherOperation.apply + + +class ThreeNN(Function): + + @staticmethod + def forward(ctx, unknown: torch.Tensor, known: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Find the three nearest neighbors of unknown in known + :param ctx: + :param unknown: (B, N, 3) + :param known: (B, M, 3) + :return: + dist: (B, N, 3) l2 distance to the three nearest neighbors + idx: (B, N, 3) index of 3 nearest neighbors + """ + assert unknown.is_contiguous() + assert known.is_contiguous() + + B, N, _ = unknown.size() + m = known.size(1) + dist2 = torch.cuda.FloatTensor(B, N, 3) + idx = torch.cuda.IntTensor(B, N, 3) + + pointnet2.three_nn_wrapper(B, N, m, unknown, known, dist2, idx) + return torch.sqrt(dist2), idx + + @staticmethod + def backward(ctx, a=None, b=None): + return None, None + + +three_nn = ThreeNN.apply + + +class ThreeInterpolate(Function): + + @staticmethod + def forward(ctx, features: torch.Tensor, idx: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: + """ + Performs weight linear interpolation on 3 features + :param ctx: + :param features: (B, C, M) Features descriptors to be interpolated from + :param idx: (B, n, 3) three nearest neighbors of the target features in features + :param weight: (B, n, 3) weights + :return: + output: (B, C, N) tensor of the interpolated features + """ + assert features.is_contiguous() + assert idx.is_contiguous() + assert weight.is_contiguous() + + B, c, m = features.size() + n = idx.size(1) + ctx.three_interpolate_for_backward = (idx, weight, m) + output = torch.cuda.FloatTensor(B, c, n) + + pointnet2.three_interpolate_wrapper(B, c, m, n, features, idx, weight, output) + return output + + @staticmethod + def backward(ctx, grad_out: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + :param ctx: + :param grad_out: (B, C, N) tensor with gradients of outputs + :return: + grad_features: (B, C, M) tensor with gradients of features + None: + None: + """ + idx, weight, m = ctx.three_interpolate_for_backward + B, c, n = grad_out.size() + + grad_features = Variable(torch.cuda.FloatTensor(B, c, m).zero_()) + grad_out_data = grad_out.data.contiguous() + + pointnet2.three_interpolate_grad_wrapper(B, c, n, m, grad_out_data, idx, weight, grad_features.data) + return grad_features, None, None + + +three_interpolate = 
ThreeInterpolate.apply + + +class GroupingOperation(Function): + + @staticmethod + def forward(ctx, features: torch.Tensor, idx: torch.Tensor) -> torch.Tensor: + """ + :param ctx: + :param features: (B, C, N) tensor of features to group + :param idx: (B, npoint, nsample) tensor containing the indicies of features to group with + :return: + output: (B, C, npoint, nsample) tensor + """ + assert features.is_contiguous() + assert idx.is_contiguous() + + B, nfeatures, nsample = idx.size() + _, C, N = features.size() + output = torch.cuda.FloatTensor(B, C, nfeatures, nsample) + + pointnet2.group_points_wrapper(B, C, N, nfeatures, nsample, features, idx, output) + + ctx.for_backwards = (idx, N) + return output + + @staticmethod + def backward(ctx, grad_out: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + :param ctx: + :param grad_out: (B, C, npoint, nsample) tensor of the gradients of the output from forward + :return: + grad_features: (B, C, N) gradient of the features + """ + idx, N = ctx.for_backwards + + B, C, npoint, nsample = grad_out.size() + grad_features = Variable(torch.cuda.FloatTensor(B, C, N).zero_()) + + grad_out_data = grad_out.data.contiguous() + pointnet2.group_points_grad_wrapper(B, C, N, npoint, nsample, grad_out_data, idx, grad_features.data) + return grad_features, None + + +grouping_operation = GroupingOperation.apply + + +class BallQuery(Function): + + @staticmethod + def forward(ctx, radius: float, nsample: int, xyz: torch.Tensor, new_xyz: torch.Tensor) -> torch.Tensor: + """ + :param ctx: + :param radius: float, radius of the balls + :param nsample: int, maximum number of features in the balls + :param xyz: (B, N, 3) xyz coordinates of the features + :param new_xyz: (B, npoint, 3) centers of the ball query + :return: + idx: (B, npoint, nsample) tensor with the indicies of the features that form the query balls + """ + assert new_xyz.is_contiguous() + assert xyz.is_contiguous() + + B, N, _ = xyz.size() + npoint = new_xyz.size(1) + idx = torch.cuda.IntTensor(B, npoint, nsample).zero_() + + pointnet2.ball_query_wrapper(B, N, npoint, radius, nsample, new_xyz, xyz, idx) + return idx + + @staticmethod + def backward(ctx, a=None): + return None, None, None, None + + +ball_query = BallQuery.apply + + +class QueryAndGroup(nn.Module): + def __init__(self, radius: float, nsample: int, use_xyz: bool = True): + """ + :param radius: float, radius of ball + :param nsample: int, maximum number of features to gather in the ball + :param use_xyz: + """ + super().__init__() + self.radius, self.nsample, self.use_xyz = radius, nsample, use_xyz + + def forward(self, xyz: torch.Tensor, new_xyz: torch.Tensor, features: torch.Tensor = None) -> Tuple[torch.Tensor]: + """ + :param xyz: (B, N, 3) xyz coordinates of the features + :param new_xyz: (B, npoint, 3) centroids + :param features: (B, C, N) descriptors of the features + :return: + new_features: (B, 3 + C, npoint, nsample) + """ + idx = ball_query(self.radius, self.nsample, xyz, new_xyz) + xyz_trans = xyz.transpose(1, 2).contiguous() + grouped_xyz = grouping_operation(xyz_trans, idx) # (B, 3, npoint, nsample) + grouped_xyz -= new_xyz.transpose(1, 2).unsqueeze(-1) + + if features is not None: + grouped_features = grouping_operation(features, idx) + if self.use_xyz: + new_features = torch.cat([grouped_xyz, grouped_features], dim=1) # (B, C + 3, npoint, nsample) + else: + new_features = grouped_features + else: + assert self.use_xyz, "Cannot have not features and not use xyz as a feature!" 
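+ # no per-point features were provided, so the grouped (relative) xyz itself serves as the feature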
+ new_features = grouped_xyz + + return new_features + + +class GroupAll(nn.Module): + def __init__(self, use_xyz: bool = True): + super().__init__() + self.use_xyz = use_xyz + + def forward(self, xyz: torch.Tensor, new_xyz: torch.Tensor, features: torch.Tensor = None): + """ + :param xyz: (B, N, 3) xyz coordinates of the features + :param new_xyz: ignored + :param features: (B, C, N) descriptors of the features + :return: + new_features: (B, C + 3, 1, N) + """ + grouped_xyz = xyz.transpose(1, 2).unsqueeze(2) + if features is not None: + grouped_features = features.unsqueeze(2) + if self.use_xyz: + new_features = torch.cat([grouped_xyz, grouped_features], dim=1) # (B, 3 + C, 1, N) + else: + new_features = grouped_features + else: + new_features = grouped_xyz + + return new_features diff --git a/lib/pointnet2/pytorch_utils.py b/lib/pointnet2/pytorch_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..09cb7bc76d88dde5757ac70b6e05e1e0c768cc1b --- /dev/null +++ b/lib/pointnet2/pytorch_utils.py @@ -0,0 +1,236 @@ +import torch.nn as nn +from typing import List, Tuple + + +class SharedMLP(nn.Sequential): + + def __init__( + self, + args: List[int], + *, + bn: bool = False, + activation=nn.ReLU(inplace=True), + preact: bool = False, + first: bool = False, + name: str = "", + instance_norm: bool = False, + ): + super().__init__() + + for i in range(len(args) - 1): + self.add_module( + name + 'layer{}'.format(i), + Conv2d( + args[i], + args[i + 1], + bn=(not first or not preact or (i != 0)) and bn, + activation=activation + if (not first or not preact or (i != 0)) else None, + preact=preact, + instance_norm=instance_norm + ) + ) + + +class _ConvBase(nn.Sequential): + + def __init__( + self, + in_size, + out_size, + kernel_size, + stride, + padding, + activation, + bn, + init, + conv=None, + batch_norm=None, + bias=True, + preact=False, + name="", + instance_norm=False, + instance_norm_func=None + ): + super().__init__() + + bias = bias and (not bn) + conv_unit = conv( + in_size, + out_size, + kernel_size=kernel_size, + stride=stride, + padding=padding, + bias=bias + ) + init(conv_unit.weight) + if bias: + nn.init.constant_(conv_unit.bias, 0) + + if bn: + if not preact: + bn_unit = batch_norm(out_size) + else: + bn_unit = batch_norm(in_size) + if instance_norm: + if not preact: + in_unit = instance_norm_func(out_size, affine=False, track_running_stats=False) + else: + in_unit = instance_norm_func(in_size, affine=False, track_running_stats=False) + + if preact: + if bn: + self.add_module(name + 'bn', bn_unit) + + if activation is not None: + self.add_module(name + 'activation', activation) + + if not bn and instance_norm: + self.add_module(name + 'in', in_unit) + + self.add_module(name + 'conv', conv_unit) + + if not preact: + if bn: + self.add_module(name + 'bn', bn_unit) + + if activation is not None: + self.add_module(name + 'activation', activation) + + if not bn and instance_norm: + self.add_module(name + 'in', in_unit) + + +class _BNBase(nn.Sequential): + + def __init__(self, in_size, batch_norm=None, name=""): + super().__init__() + self.add_module(name + "bn", batch_norm(in_size)) + + nn.init.constant_(self[0].weight, 1.0) + nn.init.constant_(self[0].bias, 0) + + +class BatchNorm1d(_BNBase): + + def __init__(self, in_size: int, *, name: str = ""): + super().__init__(in_size, batch_norm=nn.BatchNorm1d, name=name) + + +class BatchNorm2d(_BNBase): + + def __init__(self, in_size: int, name: str = ""): + super().__init__(in_size, batch_norm=nn.BatchNorm2d, name=name) + + 
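+# Conv1d / Conv2d below wrap _ConvBase: a convolution plus optional batch/instance
+# norm and activation, with preact controlling whether normalization and activation
+# run before or after the convolution. SharedMLP stacks these Conv2d blocks.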
+class Conv1d(_ConvBase): + + def __init__( + self, + in_size: int, + out_size: int, + *, + kernel_size: int = 1, + stride: int = 1, + padding: int = 0, + activation=nn.ReLU(inplace=True), + bn: bool = False, + init=nn.init.kaiming_normal_, + bias: bool = True, + preact: bool = False, + name: str = "", + instance_norm=False + ): + super().__init__( + in_size, + out_size, + kernel_size, + stride, + padding, + activation, + bn, + init, + conv=nn.Conv1d, + batch_norm=BatchNorm1d, + bias=bias, + preact=preact, + name=name, + instance_norm=instance_norm, + instance_norm_func=nn.InstanceNorm1d + ) + + +class Conv2d(_ConvBase): + + def __init__( + self, + in_size: int, + out_size: int, + *, + kernel_size: Tuple[int, int] = (1, 1), + stride: Tuple[int, int] = (1, 1), + padding: Tuple[int, int] = (0, 0), + activation=nn.ReLU(inplace=True), + bn: bool = False, + init=nn.init.kaiming_normal_, + bias: bool = True, + preact: bool = False, + name: str = "", + instance_norm=False + ): + super().__init__( + in_size, + out_size, + kernel_size, + stride, + padding, + activation, + bn, + init, + conv=nn.Conv2d, + batch_norm=BatchNorm2d, + bias=bias, + preact=preact, + name=name, + instance_norm=instance_norm, + instance_norm_func=nn.InstanceNorm2d + ) + + +class FC(nn.Sequential): + + def __init__( + self, + in_size: int, + out_size: int, + *, + activation=nn.ReLU(inplace=True), + bn: bool = False, + init=None, + preact: bool = False, + name: str = "" + ): + super().__init__() + + fc = nn.Linear(in_size, out_size, bias=not bn) + if init is not None: + init(fc.weight) + if not bn: + nn.init.constant(fc.bias, 0) + + if preact: + if bn: + self.add_module(name + 'bn', BatchNorm1d(in_size)) + + if activation is not None: + self.add_module(name + 'activation', activation) + + self.add_module(name + 'fc', fc) + + if not preact: + if bn: + self.add_module(name + 'bn', BatchNorm1d(out_size)) + + if activation is not None: + self.add_module(name + 'activation', activation) + diff --git a/lib/pointnet2/setup.py b/lib/pointnet2/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..99e59e37b90517cc38c35d100f7f9cee0e309368 --- /dev/null +++ b/lib/pointnet2/setup.py @@ -0,0 +1,23 @@ +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension + +setup( + name='pointnet2', + ext_modules=[ + CUDAExtension('pointnet2_cuda', [ + 'src/pointnet2_api.cpp', + + 'src/ball_query.cpp', + 'src/ball_query_gpu.cu', + 'src/group_points.cpp', + 'src/group_points_gpu.cu', + 'src/interpolate.cpp', + 'src/interpolate_gpu.cu', + 'src/sampling.cpp', + 'src/sampling_gpu.cu', + ], + extra_compile_args={'cxx': ['-g'], + 'nvcc': ['-O2']}) + ], + cmdclass={'build_ext': BuildExtension} +) diff --git a/lib/pointnet2/src/ball_query.cpp b/lib/pointnet2/src/ball_query.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c9b176e5da5dd89a3378652f0b806925e8ee8996 --- /dev/null +++ b/lib/pointnet2/src/ball_query.cpp @@ -0,0 +1,24 @@ +#include +#include +#include +#include +#include +#include +#include "ball_query_gpu.h" + +#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x, " must be a CUDAtensor ") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x, " must be contiguous ") +#define CHECK_INPUT(x) CHECK_CUDA(x);CHECK_CONTIGUOUS(x) + +int ball_query_wrapper_fast(int b, int n, int m, float radius, int nsample, + at::Tensor new_xyz_tensor, at::Tensor xyz_tensor, at::Tensor idx_tensor) { + CHECK_INPUT(new_xyz_tensor); + CHECK_INPUT(xyz_tensor); + const float 
*new_xyz = new_xyz_tensor.data(); + const float *xyz = xyz_tensor.data(); + int *idx = idx_tensor.data(); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + ball_query_kernel_launcher_fast(b, n, m, radius, nsample, new_xyz, xyz, idx, stream); + return 1; +} diff --git a/lib/pointnet2/src/ball_query_gpu.cu b/lib/pointnet2/src/ball_query_gpu.cu new file mode 100644 index 0000000000000000000000000000000000000000..f8840aa6650693cea17d337008a15fef13ec1ebc --- /dev/null +++ b/lib/pointnet2/src/ball_query_gpu.cu @@ -0,0 +1,67 @@ +#include +#include +#include + +#include "ball_query_gpu.h" +#include "cuda_utils.h" + + +__global__ void ball_query_kernel_fast(int b, int n, int m, float radius, int nsample, + const float *__restrict__ new_xyz, const float *__restrict__ xyz, int *__restrict__ idx) { + // new_xyz: (B, M, 3) + // xyz: (B, N, 3) + // output: + // idx: (B, M, nsample) + int bs_idx = blockIdx.y; + int pt_idx = blockIdx.x * blockDim.x + threadIdx.x; + if (bs_idx >= b || pt_idx >= m) return; + + new_xyz += bs_idx * m * 3 + pt_idx * 3; + xyz += bs_idx * n * 3; + idx += bs_idx * m * nsample + pt_idx * nsample; + + float radius2 = radius * radius; + float new_x = new_xyz[0]; + float new_y = new_xyz[1]; + float new_z = new_xyz[2]; + + int cnt = 0; + for (int k = 0; k < n; ++k) { + float x = xyz[k * 3 + 0]; + float y = xyz[k * 3 + 1]; + float z = xyz[k * 3 + 2]; + float d2 = (new_x - x) * (new_x - x) + (new_y - y) * (new_y - y) + (new_z - z) * (new_z - z); + if (d2 < radius2){ + if (cnt == 0){ + for (int l = 0; l < nsample; ++l) { + idx[l] = k; + } + } + idx[cnt] = k; + ++cnt; + if (cnt >= nsample) break; + } + } +} + + +void ball_query_kernel_launcher_fast(int b, int n, int m, float radius, int nsample, \ + const float *new_xyz, const float *xyz, int *idx, cudaStream_t stream) { + // new_xyz: (B, M, 3) + // xyz: (B, N, 3) + // output: + // idx: (B, M, nsample) + + cudaError_t err; + + dim3 blocks(DIVUP(m, THREADS_PER_BLOCK), b); // blockIdx.x(col), blockIdx.y(row) + dim3 threads(THREADS_PER_BLOCK); + + ball_query_kernel_fast<<>>(b, n, m, radius, nsample, new_xyz, xyz, idx); + // cudaDeviceSynchronize(); // for using printf in kernel function + err = cudaGetLastError(); + if (cudaSuccess != err) { + fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err)); + exit(-1); + } +} \ No newline at end of file diff --git a/lib/pointnet2/src/ball_query_gpu.h b/lib/pointnet2/src/ball_query_gpu.h new file mode 100644 index 0000000000000000000000000000000000000000..ffc831a8b700f46b50e0b90d49c538aa0fedca50 --- /dev/null +++ b/lib/pointnet2/src/ball_query_gpu.h @@ -0,0 +1,15 @@ +#ifndef _BALL_QUERY_GPU_H +#define _BALL_QUERY_GPU_H + +#include +#include +#include +#include + +int ball_query_wrapper_fast(int b, int n, int m, float radius, int nsample, + at::Tensor new_xyz_tensor, at::Tensor xyz_tensor, at::Tensor idx_tensor); + +void ball_query_kernel_launcher_fast(int b, int n, int m, float radius, int nsample, + const float *xyz, const float *new_xyz, int *idx, cudaStream_t stream); + +#endif diff --git a/lib/pointnet2/src/cuda_utils.h b/lib/pointnet2/src/cuda_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..7fe27969179c976a88199bbe962ca4f8d97263a4 --- /dev/null +++ b/lib/pointnet2/src/cuda_utils.h @@ -0,0 +1,15 @@ +#ifndef _CUDA_UTILS_H +#define _CUDA_UTILS_H + +#include + +#define TOTAL_THREADS 1024 +#define THREADS_PER_BLOCK 256 +#define DIVUP(m,n) ((m) / (n) + ((m) % (n) > 0)) + +inline int opt_n_threads(int work_size) { + const int pow_2 = 
std::log(static_cast(work_size)) / std::log(2.0); + + return max(min(1 << pow_2, TOTAL_THREADS), 1); +} +#endif diff --git a/lib/pointnet2/src/group_points.cpp b/lib/pointnet2/src/group_points.cpp new file mode 100644 index 0000000000000000000000000000000000000000..fa80f0e318acc57dabf76ec0a8b1d9dff482ab89 --- /dev/null +++ b/lib/pointnet2/src/group_points.cpp @@ -0,0 +1,34 @@ +#include +#include +#include +#include +#include "group_points_gpu.h" +#include +#include + + + +int group_points_grad_wrapper_fast(int b, int c, int n, int npoints, int nsample, + at::Tensor grad_out_tensor, at::Tensor idx_tensor, at::Tensor grad_points_tensor) { + + float *grad_points = grad_points_tensor.data(); + const int *idx = idx_tensor.data(); + const float *grad_out = grad_out_tensor.data(); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + group_points_grad_kernel_launcher_fast(b, c, n, npoints, nsample, grad_out, idx, grad_points, stream); + return 1; +} + + +int group_points_wrapper_fast(int b, int c, int n, int npoints, int nsample, + at::Tensor points_tensor, at::Tensor idx_tensor, at::Tensor out_tensor) { + + const float *points = points_tensor.data(); + const int *idx = idx_tensor.data(); + float *out = out_tensor.data(); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + group_points_kernel_launcher_fast(b, c, n, npoints, nsample, points, idx, out, stream); + return 1; +} diff --git a/lib/pointnet2/src/group_points_gpu.cu b/lib/pointnet2/src/group_points_gpu.cu new file mode 100644 index 0000000000000000000000000000000000000000..c015a8125e38aafa1f960000044978463b7853b1 --- /dev/null +++ b/lib/pointnet2/src/group_points_gpu.cu @@ -0,0 +1,86 @@ +#include +#include + +#include "cuda_utils.h" +#include "group_points_gpu.h" + + +__global__ void group_points_grad_kernel_fast(int b, int c, int n, int npoints, int nsample, + const float *__restrict__ grad_out, const int *__restrict__ idx, float *__restrict__ grad_points) { + // grad_out: (B, C, npoints, nsample) + // idx: (B, npoints, nsample) + // output: + // grad_points: (B, C, N) + int bs_idx = blockIdx.z; + int c_idx = blockIdx.y; + int index = blockIdx.x * blockDim.x + threadIdx.x; + int pt_idx = index / nsample; + if (bs_idx >= b || c_idx >= c || pt_idx >= npoints) return; + + int sample_idx = index % nsample; + grad_out += bs_idx * c * npoints * nsample + c_idx * npoints * nsample + pt_idx * nsample + sample_idx; + idx += bs_idx * npoints * nsample + pt_idx * nsample + sample_idx; + + atomicAdd(grad_points + bs_idx * c * n + c_idx * n + idx[0] , grad_out[0]); +} + +void group_points_grad_kernel_launcher_fast(int b, int c, int n, int npoints, int nsample, + const float *grad_out, const int *idx, float *grad_points, cudaStream_t stream) { + // grad_out: (B, C, npoints, nsample) + // idx: (B, npoints, nsample) + // output: + // grad_points: (B, C, N) + cudaError_t err; + dim3 blocks(DIVUP(npoints * nsample, THREADS_PER_BLOCK), c, b); // blockIdx.x(col), blockIdx.y(row) + dim3 threads(THREADS_PER_BLOCK); + + group_points_grad_kernel_fast<<>>(b, c, n, npoints, nsample, grad_out, idx, grad_points); + + err = cudaGetLastError(); + if (cudaSuccess != err) { + fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err)); + exit(-1); + } +} + + +__global__ void group_points_kernel_fast(int b, int c, int n, int npoints, int nsample, + const float *__restrict__ points, const int *__restrict__ idx, float *__restrict__ out) { + // points: (B, C, N) + // idx: (B, npoints, nsample) + // output: + // out: (B, C, 
npoints, nsample) + int bs_idx = blockIdx.z; + int c_idx = blockIdx.y; + int index = blockIdx.x * blockDim.x + threadIdx.x; + int pt_idx = index / nsample; + if (bs_idx >= b || c_idx >= c || pt_idx >= npoints) return; + + int sample_idx = index % nsample; + + idx += bs_idx * npoints * nsample + pt_idx * nsample + sample_idx; + int in_idx = bs_idx * c * n + c_idx * n + idx[0]; + int out_idx = bs_idx * c * npoints * nsample + c_idx * npoints * nsample + pt_idx * nsample + sample_idx; + + out[out_idx] = points[in_idx]; +} + + +void group_points_kernel_launcher_fast(int b, int c, int n, int npoints, int nsample, + const float *points, const int *idx, float *out, cudaStream_t stream) { + // points: (B, C, N) + // idx: (B, npoints, nsample) + // output: + // out: (B, C, npoints, nsample) + cudaError_t err; + dim3 blocks(DIVUP(npoints * nsample, THREADS_PER_BLOCK), c, b); // blockIdx.x(col), blockIdx.y(row) + dim3 threads(THREADS_PER_BLOCK); + + group_points_kernel_fast<<>>(b, c, n, npoints, nsample, points, idx, out); + // cudaDeviceSynchronize(); // for using printf in kernel function + err = cudaGetLastError(); + if (cudaSuccess != err) { + fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err)); + exit(-1); + } +} diff --git a/lib/pointnet2/src/group_points_gpu.h b/lib/pointnet2/src/group_points_gpu.h new file mode 100644 index 0000000000000000000000000000000000000000..76c73ca2600ef75c192b06d28f79a168f1ba368b --- /dev/null +++ b/lib/pointnet2/src/group_points_gpu.h @@ -0,0 +1,22 @@ +#ifndef _GROUP_POINTS_GPU_H +#define _GROUP_POINTS_GPU_H + +#include +#include +#include +#include + + +int group_points_wrapper_fast(int b, int c, int n, int npoints, int nsample, + at::Tensor points_tensor, at::Tensor idx_tensor, at::Tensor out_tensor); + +void group_points_kernel_launcher_fast(int b, int c, int n, int npoints, int nsample, + const float *points, const int *idx, float *out, cudaStream_t stream); + +int group_points_grad_wrapper_fast(int b, int c, int n, int npoints, int nsample, + at::Tensor grad_out_tensor, at::Tensor idx_tensor, at::Tensor grad_points_tensor); + +void group_points_grad_kernel_launcher_fast(int b, int c, int n, int npoints, int nsample, + const float *grad_out, const int *idx, float *grad_points, cudaStream_t stream); + +#endif diff --git a/lib/pointnet2/src/interpolate.cpp b/lib/pointnet2/src/interpolate.cpp new file mode 100644 index 0000000000000000000000000000000000000000..88d837f966f52696308b7d85ec1756b2395bb986 --- /dev/null +++ b/lib/pointnet2/src/interpolate.cpp @@ -0,0 +1,53 @@ +#include +#include +#include +#include +#include +#include +#include +#include "interpolate_gpu.h" +#include +#include + + +void three_nn_wrapper_fast(int b, int n, int m, at::Tensor unknown_tensor, + at::Tensor known_tensor, at::Tensor dist2_tensor, at::Tensor idx_tensor) { + const float *unknown = unknown_tensor.data(); + const float *known = known_tensor.data(); + float *dist2 = dist2_tensor.data(); + int *idx = idx_tensor.data(); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + three_nn_kernel_launcher_fast(b, n, m, unknown, known, dist2, idx, stream); +} + + +void three_interpolate_wrapper_fast(int b, int c, int m, int n, + at::Tensor points_tensor, + at::Tensor idx_tensor, + at::Tensor weight_tensor, + at::Tensor out_tensor) { + + const float *points = points_tensor.data(); + const float *weight = weight_tensor.data(); + float *out = out_tensor.data(); + const int *idx = idx_tensor.data(); + + cudaStream_t stream = 
at::cuda::getCurrentCUDAStream().stream(); + three_interpolate_kernel_launcher_fast(b, c, m, n, points, idx, weight, out, stream); +} + +void three_interpolate_grad_wrapper_fast(int b, int c, int n, int m, + at::Tensor grad_out_tensor, + at::Tensor idx_tensor, + at::Tensor weight_tensor, + at::Tensor grad_points_tensor) { + + const float *grad_out = grad_out_tensor.data(); + const float *weight = weight_tensor.data(); + float *grad_points = grad_points_tensor.data(); + const int *idx = idx_tensor.data(); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + three_interpolate_grad_kernel_launcher_fast(b, c, n, m, grad_out, idx, weight, grad_points, stream); +} diff --git a/lib/pointnet2/src/interpolate_gpu.cu b/lib/pointnet2/src/interpolate_gpu.cu new file mode 100644 index 0000000000000000000000000000000000000000..a123dd8d8d4f5ed23cc4a340abb1141d140fca3c --- /dev/null +++ b/lib/pointnet2/src/interpolate_gpu.cu @@ -0,0 +1,161 @@ +#include +#include +#include + +#include "cuda_utils.h" +#include "interpolate_gpu.h" + + +__global__ void three_nn_kernel_fast(int b, int n, int m, const float *__restrict__ unknown, + const float *__restrict__ known, float *__restrict__ dist2, int *__restrict__ idx) { + // unknown: (B, N, 3) + // known: (B, M, 3) + // output: + // dist2: (B, N, 3) + // idx: (B, N, 3) + + int bs_idx = blockIdx.y; + int pt_idx = blockIdx.x * blockDim.x + threadIdx.x; + if (bs_idx >= b || pt_idx >= n) return; + + unknown += bs_idx * n * 3 + pt_idx * 3; + known += bs_idx * m * 3; + dist2 += bs_idx * n * 3 + pt_idx * 3; + idx += bs_idx * n * 3 + pt_idx * 3; + + float ux = unknown[0]; + float uy = unknown[1]; + float uz = unknown[2]; + + double best1 = 1e40, best2 = 1e40, best3 = 1e40; + int besti1 = 0, besti2 = 0, besti3 = 0; + for (int k = 0; k < m; ++k) { + float x = known[k * 3 + 0]; + float y = known[k * 3 + 1]; + float z = known[k * 3 + 2]; + float d = (ux - x) * (ux - x) + (uy - y) * (uy - y) + (uz - z) * (uz - z); + if (d < best1) { + best3 = best2; besti3 = besti2; + best2 = best1; besti2 = besti1; + best1 = d; besti1 = k; + } + else if (d < best2) { + best3 = best2; besti3 = besti2; + best2 = d; besti2 = k; + } + else if (d < best3) { + best3 = d; besti3 = k; + } + } + dist2[0] = best1; dist2[1] = best2; dist2[2] = best3; + idx[0] = besti1; idx[1] = besti2; idx[2] = besti3; +} + + +void three_nn_kernel_launcher_fast(int b, int n, int m, const float *unknown, + const float *known, float *dist2, int *idx, cudaStream_t stream) { + // unknown: (B, N, 3) + // known: (B, M, 3) + // output: + // dist2: (B, N, 3) + // idx: (B, N, 3) + + cudaError_t err; + dim3 blocks(DIVUP(n, THREADS_PER_BLOCK), b); // blockIdx.x(col), blockIdx.y(row) + dim3 threads(THREADS_PER_BLOCK); + + three_nn_kernel_fast<<>>(b, n, m, unknown, known, dist2, idx); + + err = cudaGetLastError(); + if (cudaSuccess != err) { + fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err)); + exit(-1); + } +} + + +__global__ void three_interpolate_kernel_fast(int b, int c, int m, int n, const float *__restrict__ points, + const int *__restrict__ idx, const float *__restrict__ weight, float *__restrict__ out) { + // points: (B, C, M) + // idx: (B, N, 3) + // weight: (B, N, 3) + // output: + // out: (B, C, N) + + int bs_idx = blockIdx.z; + int c_idx = blockIdx.y; + int pt_idx = blockIdx.x * blockDim.x + threadIdx.x; + + if (bs_idx >= b || c_idx >= c || pt_idx >= n) return; + + weight += bs_idx * n * 3 + pt_idx * 3; + points += bs_idx * c * m + c_idx * m; + idx += bs_idx * n * 3 + pt_idx * 3; + 
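+ // pointers are offset to this thread's (batch, channel) slice; out[pt_idx] below is the
+ // weighted sum of the three nearest known features selected by idx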
out += bs_idx * c * n + c_idx * n; + + out[pt_idx] = weight[0] * points[idx[0]] + weight[1] * points[idx[1]] + weight[2] * points[idx[2]]; +} + +void three_interpolate_kernel_launcher_fast(int b, int c, int m, int n, + const float *points, const int *idx, const float *weight, float *out, cudaStream_t stream) { + // points: (B, C, M) + // idx: (B, N, 3) + // weight: (B, N, 3) + // output: + // out: (B, C, N) + + cudaError_t err; + dim3 blocks(DIVUP(n, THREADS_PER_BLOCK), c, b); // blockIdx.x(col), blockIdx.y(row) + dim3 threads(THREADS_PER_BLOCK); + three_interpolate_kernel_fast<<>>(b, c, m, n, points, idx, weight, out); + + err = cudaGetLastError(); + if (cudaSuccess != err) { + fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err)); + exit(-1); + } +} + + +__global__ void three_interpolate_grad_kernel_fast(int b, int c, int n, int m, const float *__restrict__ grad_out, + const int *__restrict__ idx, const float *__restrict__ weight, float *__restrict__ grad_points) { + // grad_out: (B, C, N) + // weight: (B, N, 3) + // output: + // grad_points: (B, C, M) + + int bs_idx = blockIdx.z; + int c_idx = blockIdx.y; + int pt_idx = blockIdx.x * blockDim.x + threadIdx.x; + + if (bs_idx >= b || c_idx >= c || pt_idx >= n) return; + + grad_out += bs_idx * c * n + c_idx * n + pt_idx; + weight += bs_idx * n * 3 + pt_idx * 3; + grad_points += bs_idx * c * m + c_idx * m; + idx += bs_idx * n * 3 + pt_idx * 3; + + + atomicAdd(grad_points + idx[0], grad_out[0] * weight[0]); + atomicAdd(grad_points + idx[1], grad_out[0] * weight[1]); + atomicAdd(grad_points + idx[2], grad_out[0] * weight[2]); +} + +void three_interpolate_grad_kernel_launcher_fast(int b, int c, int n, int m, const float *grad_out, + const int *idx, const float *weight, float *grad_points, cudaStream_t stream) { + // grad_out: (B, C, N) + // weight: (B, N, 3) + // output: + // grad_points: (B, C, M) + + cudaError_t err; + dim3 blocks(DIVUP(n, THREADS_PER_BLOCK), c, b); // blockIdx.x(col), blockIdx.y(row) + dim3 threads(THREADS_PER_BLOCK); + three_interpolate_grad_kernel_fast<<>>(b, c, n, m, grad_out, idx, weight, grad_points); + + err = cudaGetLastError(); + if (cudaSuccess != err) { + fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err)); + exit(-1); + } +} \ No newline at end of file diff --git a/lib/pointnet2/src/interpolate_gpu.h b/lib/pointnet2/src/interpolate_gpu.h new file mode 100644 index 0000000000000000000000000000000000000000..f1771087c5e4146e3c5775d3b929ebffffd11ccb --- /dev/null +++ b/lib/pointnet2/src/interpolate_gpu.h @@ -0,0 +1,30 @@ +#ifndef _INTERPOLATE_GPU_H +#define _INTERPOLATE_GPU_H + +#include +#include +#include +#include + + +void three_nn_wrapper_fast(int b, int n, int m, at::Tensor unknown_tensor, + at::Tensor known_tensor, at::Tensor dist2_tensor, at::Tensor idx_tensor); + +void three_nn_kernel_launcher_fast(int b, int n, int m, const float *unknown, + const float *known, float *dist2, int *idx, cudaStream_t stream); + + +void three_interpolate_wrapper_fast(int b, int c, int m, int n, at::Tensor points_tensor, + at::Tensor idx_tensor, at::Tensor weight_tensor, at::Tensor out_tensor); + +void three_interpolate_kernel_launcher_fast(int b, int c, int m, int n, + const float *points, const int *idx, const float *weight, float *out, cudaStream_t stream); + + +void three_interpolate_grad_wrapper_fast(int b, int c, int n, int m, at::Tensor grad_out_tensor, + at::Tensor idx_tensor, at::Tensor weight_tensor, at::Tensor grad_points_tensor); + +void three_interpolate_grad_kernel_launcher_fast(int 
b, int c, int n, int m, const float *grad_out, + const int *idx, const float *weight, float *grad_points, cudaStream_t stream); + +#endif diff --git a/lib/pointnet2/src/pointnet2_api.cpp b/lib/pointnet2/src/pointnet2_api.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d91f0f2176a6080624f071e5535fe509a0ac83c4 --- /dev/null +++ b/lib/pointnet2/src/pointnet2_api.cpp @@ -0,0 +1,24 @@ +#include +#include + +#include "ball_query_gpu.h" +#include "group_points_gpu.h" +#include "sampling_gpu.h" +#include "interpolate_gpu.h" + + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("ball_query_wrapper", &ball_query_wrapper_fast, "ball_query_wrapper_fast"); + + m.def("group_points_wrapper", &group_points_wrapper_fast, "group_points_wrapper_fast"); + m.def("group_points_grad_wrapper", &group_points_grad_wrapper_fast, "group_points_grad_wrapper_fast"); + + m.def("gather_points_wrapper", &gather_points_wrapper_fast, "gather_points_wrapper_fast"); + m.def("gather_points_grad_wrapper", &gather_points_grad_wrapper_fast, "gather_points_grad_wrapper_fast"); + + m.def("furthest_point_sampling_wrapper", &furthest_point_sampling_wrapper, "furthest_point_sampling_wrapper"); + + m.def("three_nn_wrapper", &three_nn_wrapper_fast, "three_nn_wrapper_fast"); + m.def("three_interpolate_wrapper", &three_interpolate_wrapper_fast, "three_interpolate_wrapper_fast"); + m.def("three_interpolate_grad_wrapper", &three_interpolate_grad_wrapper_fast, "three_interpolate_grad_wrapper_fast"); +} diff --git a/lib/pointnet2/src/sampling.cpp b/lib/pointnet2/src/sampling.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5f54daa763ed66240c17ba6254ee9d5a39b6dfc0 --- /dev/null +++ b/lib/pointnet2/src/sampling.cpp @@ -0,0 +1,45 @@ +#include +#include +#include +#include +#include +#include "sampling_gpu.h" + + + +int gather_points_wrapper_fast(int b, int c, int n, int npoints, + at::Tensor points_tensor, at::Tensor idx_tensor, at::Tensor out_tensor){ + const float *points = points_tensor.data(); + const int *idx = idx_tensor.data(); + float *out = out_tensor.data(); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + gather_points_kernel_launcher_fast(b, c, n, npoints, points, idx, out, stream); + return 1; +} + + +int gather_points_grad_wrapper_fast(int b, int c, int n, int npoints, + at::Tensor grad_out_tensor, at::Tensor idx_tensor, at::Tensor grad_points_tensor) { + + const float *grad_out = grad_out_tensor.data(); + const int *idx = idx_tensor.data(); + float *grad_points = grad_points_tensor.data(); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + gather_points_grad_kernel_launcher_fast(b, c, n, npoints, grad_out, idx, grad_points, stream); + return 1; +} + + +int furthest_point_sampling_wrapper(int b, int n, int m, + at::Tensor points_tensor, at::Tensor temp_tensor, at::Tensor idx_tensor) { + + const float *points = points_tensor.data(); + float *temp = temp_tensor.data(); + int *idx = idx_tensor.data(); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + furthest_point_sampling_kernel_launcher(b, n, m, points, temp, idx, stream); + return 1; +} diff --git a/lib/pointnet2/src/sampling_gpu.cu b/lib/pointnet2/src/sampling_gpu.cu new file mode 100644 index 0000000000000000000000000000000000000000..9e49a60dd6a80449be4c6c0d0d710be7b5fe9cd5 --- /dev/null +++ b/lib/pointnet2/src/sampling_gpu.cu @@ -0,0 +1,253 @@ +#include +#include + +#include "cuda_utils.h" +#include "sampling_gpu.h" + + +__global__ void gather_points_kernel_fast(int 
b, int c, int n, int m,
+    const float *__restrict__ points, const int *__restrict__ idx, float *__restrict__ out) {
+    // points: (B, C, N)
+    // idx: (B, M)
+    // output:
+    //      out: (B, C, M)
+
+    int bs_idx = blockIdx.z;
+    int c_idx = blockIdx.y;
+    int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
+    if (bs_idx >= b || c_idx >= c || pt_idx >= m) return;
+
+    out += bs_idx * c * m + c_idx * m + pt_idx;
+    idx += bs_idx * m + pt_idx;
+    points += bs_idx * c * n + c_idx * n;
+    out[0] = points[idx[0]];
+}
+
+void gather_points_kernel_launcher_fast(int b, int c, int n, int npoints,
+    const float *points, const int *idx, float *out, cudaStream_t stream) {
+    // points: (B, C, N)
+    // idx: (B, npoints)
+    // output:
+    //      out: (B, C, npoints)
+
+    cudaError_t err;
+    dim3 blocks(DIVUP(npoints, THREADS_PER_BLOCK), c, b);  // blockIdx.x(col), blockIdx.y(row)
+    dim3 threads(THREADS_PER_BLOCK);
+
+    gather_points_kernel_fast<<<blocks, threads, 0, stream>>>(b, c, n, npoints, points, idx, out);
+
+    err = cudaGetLastError();
+    if (cudaSuccess != err) {
+        fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
+        exit(-1);
+    }
+}
+
+__global__ void gather_points_grad_kernel_fast(int b, int c, int n, int m, const float *__restrict__ grad_out,
+    const int *__restrict__ idx, float *__restrict__ grad_points) {
+    // grad_out: (B, C, M)
+    // idx: (B, M)
+    // output:
+    //      grad_points: (B, C, N)
+
+    int bs_idx = blockIdx.z;
+    int c_idx = blockIdx.y;
+    int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
+    if (bs_idx >= b || c_idx >= c || pt_idx >= m) return;
+
+    grad_out += bs_idx * c * m + c_idx * m + pt_idx;
+    idx += bs_idx * m + pt_idx;
+    grad_points += bs_idx * c * n + c_idx * n;
+
+    atomicAdd(grad_points + idx[0], grad_out[0]);
+}
+
+void gather_points_grad_kernel_launcher_fast(int b, int c, int n, int npoints,
+    const float *grad_out, const int *idx, float *grad_points, cudaStream_t stream) {
+    // grad_out: (B, C, npoints)
+    // idx: (B, npoints)
+    // output:
+    //      grad_points: (B, C, N)
+
+    cudaError_t err;
+    dim3 blocks(DIVUP(npoints, THREADS_PER_BLOCK), c, b);  // blockIdx.x(col), blockIdx.y(row)
+    dim3 threads(THREADS_PER_BLOCK);
+
+    gather_points_grad_kernel_fast<<<blocks, threads, 0, stream>>>(b, c, n, npoints, grad_out, idx, grad_points);
+
+    err = cudaGetLastError();
+    if (cudaSuccess != err) {
+        fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
+        exit(-1);
+    }
+}
+
+
+__device__ void __update(float *__restrict__ dists, int *__restrict__ dists_i, int idx1, int idx2){
+    const float v1 = dists[idx1], v2 = dists[idx2];
+    const int i1 = dists_i[idx1], i2 = dists_i[idx2];
+    dists[idx1] = max(v1, v2);
+    dists_i[idx1] = v2 > v1 ? i2 : i1;
+}
+
+template <unsigned int block_size>
+__global__ void furthest_point_sampling_kernel(int b, int n, int m,
+    const float *__restrict__ dataset, float *__restrict__ temp, int *__restrict__ idxs) {
+    // dataset: (B, N, 3)
+    // tmp: (B, N)
+    // output:
+    //      idx: (B, M)
+
+    if (m <= 0) return;
+    __shared__ float dists[block_size];
+    __shared__ int dists_i[block_size];
+
+    int batch_index = blockIdx.x;
+    dataset += batch_index * n * 3;
+    temp += batch_index * n;
+    idxs += batch_index * m;
+
+    int tid = threadIdx.x;
+    const int stride = block_size;
+
+    int old = 0;
+    if (threadIdx.x == 0)
+        idxs[0] = old;
+
+    __syncthreads();
+    for (int j = 1; j < m; j++) {
+        int besti = 0;
+        float best = -1;
+        float x1 = dataset[old * 3 + 0];
+        float y1 = dataset[old * 3 + 1];
+        float z1 = dataset[old * 3 + 2];
+        for (int k = tid; k < n; k += stride) {
+            float x2, y2, z2;
+            x2 = dataset[k * 3 + 0];
+            y2 = dataset[k * 3 + 1];
+            z2 = dataset[k * 3 + 2];
+            // float mag = (x2 * x2) + (y2 * y2) + (z2 * z2);
+            // if (mag <= 1e-3)
+            // continue;
+
+            float d = (x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1) + (z2 - z1) * (z2 - z1);
+            float d2 = min(d, temp[k]);
+            temp[k] = d2;
+            besti = d2 > best ? k : besti;
+            best = d2 > best ? d2 : best;
+        }
+        dists[tid] = best;
+        dists_i[tid] = besti;
+        __syncthreads();
+
+        if (block_size >= 1024) {
+            if (tid < 512) {
+                __update(dists, dists_i, tid, tid + 512);
+            }
+            __syncthreads();
+        }
+
+        if (block_size >= 512) {
+            if (tid < 256) {
+                __update(dists, dists_i, tid, tid + 256);
+            }
+            __syncthreads();
+        }
+        if (block_size >= 256) {
+            if (tid < 128) {
+                __update(dists, dists_i, tid, tid + 128);
+            }
+            __syncthreads();
+        }
+        if (block_size >= 128) {
+            if (tid < 64) {
+                __update(dists, dists_i, tid, tid + 64);
+            }
+            __syncthreads();
+        }
+        if (block_size >= 64) {
+            if (tid < 32) {
+                __update(dists, dists_i, tid, tid + 32);
+            }
+            __syncthreads();
+        }
+        if (block_size >= 32) {
+            if (tid < 16) {
+                __update(dists, dists_i, tid, tid + 16);
+            }
+            __syncthreads();
+        }
+        if (block_size >= 16) {
+            if (tid < 8) {
+                __update(dists, dists_i, tid, tid + 8);
+            }
+            __syncthreads();
+        }
+        if (block_size >= 8) {
+            if (tid < 4) {
+                __update(dists, dists_i, tid, tid + 4);
+            }
+            __syncthreads();
+        }
+        if (block_size >= 4) {
+            if (tid < 2) {
+                __update(dists, dists_i, tid, tid + 2);
+            }
+            __syncthreads();
+        }
+        if (block_size >= 2) {
+            if (tid < 1) {
+                __update(dists, dists_i, tid, tid + 1);
+            }
+            __syncthreads();
+        }
+
+        old = dists_i[0];
+        if (tid == 0)
+            idxs[j] = old;
+    }
+}
+
+void furthest_point_sampling_kernel_launcher(int b, int n, int m,
+    const float *dataset, float *temp, int *idxs, cudaStream_t stream) {
+    // dataset: (B, N, 3)
+    // tmp: (B, N)
+    // output:
+    //      idx: (B, M)
+
+    cudaError_t err;
+    unsigned int n_threads = opt_n_threads(n);
+
+    switch (n_threads) {
+        case 1024:
+        furthest_point_sampling_kernel<1024><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
+        case 512:
+        furthest_point_sampling_kernel<512><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
+        case 256:
+        furthest_point_sampling_kernel<256><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
+        case 128:
+        furthest_point_sampling_kernel<128><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
+        case 64:
+        furthest_point_sampling_kernel<64><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
+        case 32:
+        furthest_point_sampling_kernel<32><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
+        case 16:
+        furthest_point_sampling_kernel<16><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
+        case 8:
+        furthest_point_sampling_kernel<8><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
+        case 4:
+        furthest_point_sampling_kernel<4><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
+        case 2:
+        furthest_point_sampling_kernel<2><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
+        case 1:
+        furthest_point_sampling_kernel<1><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
+        default:
+        furthest_point_sampling_kernel<512><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
+    }
+
+    err = cudaGetLastError();
+    if (cudaSuccess != err) {
+        fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
+        exit(-1);
+    }
+}
diff --git a/lib/pointnet2/src/sampling_gpu.h b/lib/pointnet2/src/sampling_gpu.h
new file mode 100644
index 0000000000000000000000000000000000000000..6200c5914e434ecd2fc3b36313985805f6dbe0cc
--- /dev/null
+++ b/lib/pointnet2/src/sampling_gpu.h
@@ -0,0 +1,29 @@
+#ifndef _SAMPLING_GPU_H
+#define _SAMPLING_GPU_H
+
+#include
+#include
+#include
+
+
+int gather_points_wrapper_fast(int b, int c, int n, int npoints,
+    at::Tensor points_tensor, at::Tensor idx_tensor, at::Tensor out_tensor);
+
+void gather_points_kernel_launcher_fast(int b, int c, int n, int npoints,
+    const float *points, const int *idx, float *out, cudaStream_t stream);
+
+
+int gather_points_grad_wrapper_fast(int b, int c, int n, int npoints,
+    at::Tensor grad_out_tensor, at::Tensor idx_tensor, at::Tensor grad_points_tensor);
+
+void gather_points_grad_kernel_launcher_fast(int b, int c, int n, int npoints,
+    const float *grad_out, const int *idx, float *grad_points, cudaStream_t stream);
+
+
+int furthest_point_sampling_wrapper(int b, int n, int m,
+    at::Tensor points_tensor, at::Tensor temp_tensor, at::Tensor idx_tensor);
+
+void furthest_point_sampling_kernel_launcher(int b, int n, int m,
+    const float *dataset, float *temp, int *idxs, cudaStream_t stream);
+
+#endif
diff --git a/model/LLM/__init__.py b/model/LLM/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9e8eb9e9325f1906f28a9d60d967ff76963ff1a8
--- /dev/null
+++ b/model/LLM/__init__.py
@@ -0,0 +1 @@
+from . import onellm
\ No newline at end of file
diff --git a/model/LLM/__pycache__/__init__.cpython-310.pyc b/model/LLM/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e70f6416d504770062ceb50661a6094181c47ea2
Binary files /dev/null and b/model/LLM/__pycache__/__init__.cpython-310.pyc differ
diff --git a/model/LLM/__pycache__/__init__.cpython-39.pyc b/model/LLM/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..be815601063517a817e23e64c9a6208e4e66d833
Binary files /dev/null and b/model/LLM/__pycache__/__init__.cpython-39.pyc differ
diff --git a/model/LLM/__pycache__/onellm.cpython-310.pyc b/model/LLM/__pycache__/onellm.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ccf829243e41031a865186ba965b5d98d44174f1
Binary files /dev/null and b/model/LLM/__pycache__/onellm.cpython-310.pyc differ
diff --git a/model/LLM/__pycache__/onellm.cpython-39.pyc b/model/LLM/__pycache__/onellm.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..320f3ae803542ebcea9d3414a83ae4f5e5845455
Binary files /dev/null and b/model/LLM/__pycache__/onellm.cpython-39.pyc differ
diff --git a/model/LLM/onellm.py b/model/LLM/onellm.py
new file mode 100644
index 0000000000000000000000000000000000000000..1a5195737c0448e3d83c3301acbd3fce3bcd0a4e
--- /dev/null
+++ b/model/LLM/onellm.py
@@ -0,0 +1,495 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the GNU General Public License version 3. + +from typing import Optional, Tuple +from dataclasses import dataclass +import math +import functools +import copy + +import torch +from torch import nn +import torch.nn.functional as F + +import fairscale.nn.model_parallel.initialize as fs_init +from fairscale.nn.model_parallel.layers import ( + ParallelEmbedding, + RowParallelLinear, + ColumnParallelLinear, +) +from ..components import RMSNorm +from flash_attn import flash_attn_func + +import open_clip + + +default_linear_init = nn.init.xavier_uniform_ + + +@dataclass +class ModelArgs: + dim: int = 512 + n_layers: int = 8 + n_heads: int = 8 + vocab_size: int = -1 # defined later by tokenizer + multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2 + norm_eps: float = 1e-5 + + max_batch_size: int = 32 + max_seq_len: int = 2048 + + +def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0): + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2) + [: (dim // 2)].float() / dim)) + t = torch.arange(end, device=freqs.device) # type: ignore + freqs = torch.outer(t, freqs).float() # type: ignore + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 + return freqs_cis + + +def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): + ndim = x.ndim + assert 0 <= 1 < ndim + assert freqs_cis.shape == (x.shape[1], x.shape[-1]) + shape = [d if i == 1 or i == ndim - + 1 else 1 for i, d in enumerate(x.shape)] + return freqs_cis.view(*shape) + + +def apply_rotary_emb( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) + xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + freqs_cis = reshape_for_broadcast(freqs_cis, xq_) + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) + return xq_out.type_as(xq), xk_out.type_as(xk) + + +class Attention(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + + self.n_local_heads = args.n_heads // fs_init.get_model_parallel_world_size() + self.head_dim = args.dim // args.n_heads + + self.wq = ColumnParallelLinear( + args.dim, + args.n_heads * self.head_dim, + bias=False, + gather_output=False, + init_method=default_linear_init, + ) + self.wk = ColumnParallelLinear( + args.dim, + args.n_heads * self.head_dim, + bias=False, + gather_output=False, + init_method=default_linear_init, + ) + self.wv = ColumnParallelLinear( + args.dim, + args.n_heads * self.head_dim, + bias=False, + gather_output=False, + init_method=default_linear_init, + ) + self.wo = RowParallelLinear( + args.n_heads * self.head_dim, + args.dim, + bias=False, + input_is_parallel=True, + init_method=default_linear_init, + ) + + self.flash = True + self.k_cache, self.v_cache = None, None + + def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor], prompt=None): + bsz, seqlen, _ = x.shape + xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) + + xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim) + xk = xk.view(bsz, seqlen, self.n_local_heads, self.head_dim) + xv = xv.view(bsz, seqlen, self.n_local_heads, self.head_dim) + + if freqs_cis is not None: + xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis) + + if self.k_cache is None or self.v_cache is None: + keys, values = xk, xv + else: + self.k_cache = 
self.k_cache.to(xk) + self.v_cache = self.v_cache.to(xv) + self.k_cache[:bsz, start_pos: start_pos + seqlen, :, :] = xk + self.v_cache[:bsz, start_pos: start_pos + seqlen, :, :] = xv + keys = self.k_cache[:bsz, :start_pos + seqlen] + values = self.v_cache[:bsz, :start_pos + seqlen] + + output = flash_attn_func( + xq, keys, values, dropout_p=0.0, causal=mask is not None) + output = output.contiguous().view(bsz, seqlen, -1) + + return self.wo(output) + + def allocate_kv_cache(self, max_batch_size: int, max_seq_len: int) -> None: + kv_cache_shape = (max_batch_size, max_seq_len, + self.n_local_heads, self.head_dim) + if self.k_cache is None or self.k_cache.size() != kv_cache_shape: + self.k_cache = torch.empty(kv_cache_shape) + if self.v_cache is None or self.v_cache.size() != kv_cache_shape: + self.v_cache = torch.empty(kv_cache_shape) + + def destroy_kv_cache(self) -> None: + self.k_cache, self.v_cache = None, None + + +class FeedForward(nn.Module): + def __init__( + self, + dim: int, + hidden_dim: int, + multiple_of: int, + ): + super().__init__() + hidden_dim = int(2 * hidden_dim / 3) + hidden_dim = multiple_of * \ + ((hidden_dim + multiple_of - 1) // multiple_of) + + self.w1 = ColumnParallelLinear( + dim, hidden_dim, bias=False, gather_output=False, init_method=default_linear_init, + ) + self.w2 = RowParallelLinear( + hidden_dim, dim, bias=False, input_is_parallel=True, init_method=default_linear_init + ) + self.w3 = ColumnParallelLinear( + dim, hidden_dim, bias=False, gather_output=False, init_method=default_linear_init + ) + + def _silu_gating(self, x, y): + return F.silu(x) * y + + def forward(self, x): + return self.w2(self._silu_gating(self.w1(x), self.w3(x))) + + +class TransformerBlock(nn.Module): + def __init__(self, layer_id: int, args: ModelArgs): + super().__init__() + self.n_heads = args.n_heads + self.dim = args.dim + self.head_dim = args.dim // args.n_heads + self.attention = Attention(args) + self.feed_forward = FeedForward( + dim=args.dim, hidden_dim=4 * args.dim, multiple_of=args.multiple_of + ) + self.layer_id = layer_id + self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps) + self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps) + + def _forward_ffn(self, h): + return h + self.feed_forward(self.ffn_norm(h)) + + def _forward_attention(self, x, start_pos, freqs_cis, mask, prompt): + return x + self.attention.forward(self.attention_norm(x), start_pos, freqs_cis, mask, prompt) + + def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor], prompt=None): + h = self._forward_attention(x, start_pos, freqs_cis, mask, prompt) + out = self._forward_ffn(h) + return out + + +class Mlp(nn.Module): + """ MLP as used in Vision Transformer, MLP-Mixer and related networks + """ + + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.fc2(x) + return x + + +class Transformer(nn.Module): + def __init__(self, params: ModelArgs): + super().__init__() + self.params = params + self.vocab_size = params.vocab_size + self.n_layers = params.n_layers + self.tok_embeddings = ParallelEmbedding( + params.vocab_size, params.dim, init_method=nn.init.normal_, + ) + + self.layers = 
torch.nn.ModuleList() + for layer_id in range(params.n_layers): + self.layers.append(TransformerBlock(layer_id, params)) + + self.norm = RMSNorm(params.dim, eps=params.norm_eps) + self.output = ColumnParallelLinear( + params.dim, params.vocab_size, bias=False, init_method=default_linear_init, + ) + + self.freqs_cis = precompute_freqs_cis( + self.params.dim // self.params.n_heads, self.params.max_seq_len * 2 + ) + + # load clip + self.clip, _, _ = open_clip.create_model_and_transforms( + 'ViT-L-14', pretrained='openai') + for param in self.clip.parameters(): + param.requires_grad = False + param.data = param.data.half() + self.clip.transformer = None + + self.image_words = 30 + self.cache_image_words = 0 # for inference + + clip_width = self.clip.visual.conv1.out_channels + # create modal shared modules + self.resample_layers = nn.ModuleDict() + self.num_experts = 3 + self.num_resample_layers = 8 + for expert in range(self.num_experts): + expert = str(expert) + self.resample_layers[expert] = nn.ModuleList() + resampler_params = copy.deepcopy(params) + resampler_params.n_heads = 16 + resampler_params.dim = clip_width + for layer_id in range(self.num_resample_layers): + self.resample_layers[expert].append( + TransformerBlock(layer_id, resampler_params)) + + self.conv1 = nn.ModuleDict() + self.positional_embedding = nn.ParameterDict() + self.resample_tokens = nn.ParameterDict() + self.clip_proj1 = nn.ModuleDict() + self.clip_proj2 = nn.ModuleDict() + self.routers = nn.ModuleDict() + self.start_tag = nn.ParameterDict() + self.end_tag = nn.ParameterDict() + # self.modals = ['image', 'audio', 'point', 'video', 'rgbd', 'rgbn', 'fmri', 'imu'] + self.modals = ['image', 'audio', 'video', 'rgbd', 'rgbn', 'fmri', 'imu'] + for modal in self.modals: + if modal in ['image', 'video', 'rgbn', 'rgbn']: + modal_tokens = 256 + 1 + pass + elif modal == 'audio': + self.conv1[modal] = nn.Conv2d( + 1, clip_width, kernel_size=(16, 16), stride=(10, 10)) + modal_tokens = 1212 + 1 + self.positional_embedding[modal] = nn.Parameter( + torch.empty([modal_tokens, clip_width])) + nn.init.normal_(self.positional_embedding[modal], std=0.02) + elif modal == 'point': + from lib.point_utils import PointPatchEmbed + self.conv1[modal] = PointPatchEmbed( + in_channels=6, channels=clip_width) + modal_tokens = 1024 + 1 + self.positional_embedding[modal] = nn.Parameter( + torch.empty([modal_tokens, clip_width])) + nn.init.normal_(self.positional_embedding[modal], std=0.02) + elif modal == 'fmri': + self.conv1[modal] = nn.Linear(15724, 8192) + self.positional_embedding[modal] = nn.Parameter( + torch.empty([8+1, clip_width])) + nn.init.normal_(self.positional_embedding[modal], std=0.02) + elif modal == 'imu': + self.conv1[modal] = nn.Conv1d( + in_channels=6, out_channels=clip_width, kernel_size=10, bias=False) + self.positional_embedding[modal] = nn.Parameter( + torch.empty([391+1, clip_width])) + nn.init.normal_(self.positional_embedding[modal], std=0.02) + + self.routers[modal] = Mlp( + clip_width, clip_width * 4, self.num_experts) + + self.resample_tokens[modal] = nn.Parameter( + torch.empty([1, 30, resampler_params.dim])) + nn.init.normal_(self.resample_tokens[modal], std=0.02) + + self.clip_proj1[modal] = nn.Sequential( + nn.Linear(clip_width, resampler_params.dim), + nn.LayerNorm(resampler_params.dim)) + + self.clip_proj2[modal] = nn.Sequential( + nn.Linear(resampler_params.dim, params.dim), + nn.LayerNorm(params.dim)) + + self.start_tag[modal] = nn.Parameter(torch.rand(1, 1, params.dim)) + self.end_tag[modal] = 
nn.Parameter(torch.rand(1, 1, params.dim)) + + # @torch.no_grad() + + def clip_encode_image(self, x, modal='image'): + # shape = [*, width, grid ** 2] + x = x.reshape(x.shape[0], x.shape[1], -1) + x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] + + x = torch.cat([self.clip.visual.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, + x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width] + + # use pretrained pos embeding for rest modalities + pos_embedding = self.clip.visual.positional_embedding + if modal in ['audio', 'point', 'fmri', 'imu']: + pos_embedding = self.positional_embedding[modal] + + x = x + pos_embedding.to(x.dtype) + x = self.clip.visual.ln_pre(x) + + x = x.permute(1, 0, 2) # NLD -> LND + x = self.clip.visual.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + + # preserve all spatial tokens + x = self.clip.visual.ln_post(x[:, :, :]) + + # if self.clip.visual.proj is not None: + # x = x @ self.clip.visual.proj + + return x + + def encode_image(self, x, modal='image'): + bsz = x.size(0) + T = 1 + if modal in ['image']: + # modified from CLIP + x = self.clip.visual.conv1(x) # shape = [*, width, grid, grid] + elif modal in ['audio', 'imu']: + x = self.conv1[modal](x) + elif modal == 'point': + # [B, 16384, 6] -> [B, 1024, 1024, 1] + x = self.conv1[modal](x.float()).to(x.dtype) + elif modal in ['video', 'rgbd', 'rgbn']: + # [B, 15, 3, 224, 224] + B, T = x.shape[:2] + bsz = B * T + x = x.reshape(bsz, *x.shape[2:]) + x = self.clip.visual.conv1(x) + elif modal == 'fmri': + x = self.conv1[modal](x) + # [B, 1, 8196] -> [B, 1024, 8] + x = x.reshape(x.size(0), self.clip.visual.conv1.out_channels, -1) + + image_feats = self.clip_encode_image(x, modal=modal) + # take mean on time dimension + # all inputs are reduced to [B, L, D] + bsz = int(bsz / T) + image_feats = image_feats.reshape( + bsz, T, *image_feats.shape[1:]).mean(dim=1) + + image_feats = self.clip_proj1[modal](image_feats) + image_feats = torch.cat( + [self.resample_tokens[modal].repeat(bsz, 1, 1), image_feats], dim=1) + + # routing modalites + # [B, L, D]->[B, L, N] + routing_weights = self.routers[modal](image_feats).sigmoid() + routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True) + + image_feats_experts = [] + for expert_id in range(self.num_experts): + image_feats_expert = image_feats + for layer in self.resample_layers[str(expert_id)]: + image_feats_expert = layer(image_feats_expert, 0, None, None) + + image_feats_expert = image_feats_expert[:, :self.resample_tokens[modal].size(1)] + routing_weight = routing_weights[:, :self.resample_tokens[modal].size( + 1), expert_id] + # [B, L, D] * [B, L, 1] + image_feats_expert = image_feats_expert * routing_weight[:, :, None] + + image_feats_experts.append(image_feats_expert) + + image_feats = sum(image_feats_experts) + image_feats = self.clip_proj2[modal](image_feats) + + return image_feats + + def forward(self, examples, image=None, modal='image'): + self._destroy_kv_cache() # training always disables kv cache + modal = modal[0] + _bsz, seqlen = examples.shape + h = self.tok_embeddings(examples) + self.freqs_cis = self.freqs_cis.to(h.device) + + start_pos = 0 + prefix_len = 0 + if image is not None: + h_bos, h_caption = h[:, :1], h[:, 1:] + image_tokens = self.encode_image(image, modal) + h = torch.cat((h_bos, self.start_tag[modal].expand( + _bsz, -1, -1), image_tokens, self.end_tag[modal].expand(_bsz, -1, -1), h_caption), dim=1) + # bos + image token + start_tag[modal], end_tag[modal] is used for 
caption generation + prefix_len = image_tokens.shape[1] + 1 + 1 + seqlen = h.shape[1] + + freqs_cis = self.freqs_cis[start_pos:start_pos + seqlen] + mask = torch.full((1, 1, seqlen, seqlen), float("-inf"), device=h.device) + mask = torch.triu(mask, diagonal=start_pos + 1).type_as(h) + for layer in self.layers: + h = layer(h, start_pos, freqs_cis, mask) + h = self.norm(h) + output = self.output(h[:, prefix_len:, :]) + return output + + @torch.inference_mode() + def forward_inference(self, tokens: torch.Tensor, start_pos: int, image=None, modal='image'): + modal = modal[0] if isinstance(modal, list) else modal + _bsz, seqlen = tokens.shape + if start_pos == 0: + # kv cache will not re-allocate if size is unchanged + self._allocate_kv_cache(_bsz) + h = self.tok_embeddings(tokens) + self.freqs_cis = self.freqs_cis.to(h.device) + + if image is not None: + h_bos, h_caption = h[:, :1], h[:, 1:] + image_tokens = self.encode_image(image, modal) + self.cache_image_words = image_tokens.shape[1] + h = torch.cat((h_bos, self.start_tag[modal].repeat(_bsz, 1, 1), image_tokens, self.end_tag[modal].repeat(_bsz, 1, 1), h_caption), dim=1) + seqlen = h.shape[1] + freqs_cis = self.freqs_cis[0: seqlen] + else: + if start_pos == 0: + self.cache_image_words = 0 + freqs_cis = self.freqs_cis[0: seqlen] + else: + # if image was not None when start_pos=0, + # the offset should be added to start_pos within later forward_inference calls + start_pos = start_pos + self.cache_image_words + freqs_cis = self.freqs_cis[start_pos: start_pos + seqlen] + + # freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen] + + mask = None + if seqlen > 1: + mask = torch.full((1, 1, seqlen, seqlen), float("-inf"), device=tokens.device) + mask = torch.triu(mask, diagonal=start_pos + 1).type_as(h) + + for layer in self.layers: + h = layer(h, start_pos, freqs_cis, mask) + h = self.norm(h) + output = self.output(h[:, -1, :]) # only compute last logits + return output.float() + + def _allocate_kv_cache(self, max_batch_size: int) -> None: + for layer in self.layers: + layer.attention.allocate_kv_cache( + max_batch_size, self.params.max_seq_len) + + def _destroy_kv_cache(self) -> None: + for layer in self.layers: + layer.attention.destroy_kv_cache() diff --git a/model/__init__.py b/model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/model/__pycache__/__init__.cpython-310.pyc b/model/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ab67f64cfe739a7a1c51327e5e7a0ea2afc50cd9 Binary files /dev/null and b/model/__pycache__/__init__.cpython-310.pyc differ diff --git a/model/__pycache__/__init__.cpython-39.pyc b/model/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bea7e8cd12224f18eb3eefbc92abf61852979fab Binary files /dev/null and b/model/__pycache__/__init__.cpython-39.pyc differ diff --git a/model/__pycache__/components.cpython-39.pyc b/model/__pycache__/components.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dfbf25224cf34dab4fa2f85fff462d2dbef6b4d6 Binary files /dev/null and b/model/__pycache__/components.cpython-39.pyc differ diff --git a/model/__pycache__/meta.cpython-310.pyc b/model/__pycache__/meta.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3fbd547ed81e5dc062ca75c125fe8c8a668b5ead Binary files /dev/null and b/model/__pycache__/meta.cpython-310.pyc differ diff 
--git a/model/__pycache__/meta.cpython-39.pyc b/model/__pycache__/meta.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b69c01b1098d55637fb39f5be3aed62ddaf7cf43 Binary files /dev/null and b/model/__pycache__/meta.cpython-39.pyc differ diff --git a/model/__pycache__/tokenizer.cpython-310.pyc b/model/__pycache__/tokenizer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9a4452629f6f6edcb5522834a8e5bbdfc825b48e Binary files /dev/null and b/model/__pycache__/tokenizer.cpython-310.pyc differ diff --git a/model/__pycache__/tokenizer.cpython-39.pyc b/model/__pycache__/tokenizer.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c1a8d58048c6364146decfe9c883d63dc197e359 Binary files /dev/null and b/model/__pycache__/tokenizer.cpython-39.pyc differ diff --git a/model/components.py b/model/components.py new file mode 100644 index 0000000000000000000000000000000000000000..2c8bc4e88484950988aaad4faf6d34ec1a4ec8bf --- /dev/null +++ b/model/components.py @@ -0,0 +1,57 @@ +import warnings +import torch +import torch.nn as nn + +try: + from apex.normalization import FusedRMSNorm as RMSNorm +except ImportError: + warnings.warn("Cannot import apex RMSNorm, switch to vanilla implementation") + + class RMSNorm(torch.nn.Module): + def __init__(self, dim: int, eps: float = 1e-6): + """ + Initialize the RMSNorm normalization layer. + + Args: + dim (int): The dimension of the input tensor. + eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6. + + Attributes: + eps (float): A small value added to the denominator for numerical stability. + weight (nn.Parameter): Learnable scaling parameter. + + """ + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def _norm(self, x): + """ + Apply the RMSNorm normalization to the input tensor. + + Args: + x (torch.Tensor): The input tensor. + + Returns: + torch.Tensor: The normalized tensor. + + """ + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + """ + Forward pass through the RMSNorm layer. + + Args: + x (torch.Tensor): The input tensor. + + Returns: + torch.Tensor: The output tensor after applying RMSNorm. + + """ + output = self._norm(x.float()).type_as(x) + return output * self.weight + + + + diff --git a/model/meta.py b/model/meta.py new file mode 100644 index 0000000000000000000000000000000000000000..a0ab6daaa14337633f9d3261d78248683d04c930 --- /dev/null +++ b/model/meta.py @@ -0,0 +1,175 @@ +from typing import List +import torch +import torch.nn as nn +import json +import os +from .tokenizer import Tokenizer +from . 
import LLM + +from fairscale.nn.model_parallel import initialize as fs_init + + +class MetaModel(nn.Module): + + def __init__(self, llama_type, llama_config, llama_ckpt_dir=None, tokenizer_path=None): + super().__init__() + + self.criterion = torch.nn.CrossEntropyLoss(ignore_index=0) + + ModelArgs = LLM.__dict__[llama_type].ModelArgs + Transformer = LLM.__dict__[llama_type].Transformer + + with open(llama_config, "r") as f: + params = json.loads(f.read()) + model_args: ModelArgs = ModelArgs( + max_seq_len=2048, max_batch_size=32, **params + ) + self.tokenizer = Tokenizer(model_path=tokenizer_path) + model_args.vocab_size = self.tokenizer.n_words + + model = Transformer(model_args) + mp_rank = fs_init.get_model_parallel_rank() + if llama_ckpt_dir is not None: + ckpt_path = os.path.join(llama_ckpt_dir, f"consolidated.{mp_rank:02d}.pth") + if os.path.exists(ckpt_path): + checkpoint = torch.load(ckpt_path, map_location="cpu") + msg = model.load_state_dict(checkpoint, strict=False) + print(msg) + else: + print(f'Checkpoint not found at {ckpt_path}') + self.llma = model + for name, param in self.named_parameters(): + if param.requires_grad: + print(f"Trainable param: {name}, {param.shape}, {param.dtype}") + count = sum(p.numel() for p in self.parameters() if p.requires_grad) + print(f"Parameter count : {count}") + + def forward(self, examples, labels, image=None, modal='image'): + output = self.llma(examples, image=image, modal=modal) + output = output[:, :-1, :] + labels = labels[:, 1:] + + if labels.sum() == 0: + c_loss = output.mean() * 0 + else: + c_loss = self.criterion(output.reshape(-1, 32000), labels.flatten()) + + return c_loss + + def generate( + self, + prompts: List[str], + images, + max_gen_len: int, + temperature: float = 0.8, + top_p: float = 0.95, + modal = ['image'], + ) -> List[str]: + bsz = len(prompts) + params = self.llma.params + assert bsz <= params.max_batch_size, (bsz, params.max_batch_size) + + prompt_tokens = [self.tokenizer.encode( + x, bos=True, eos=False) for x in prompts] + + min_prompt_size = min([len(t) for t in prompt_tokens]) + max_prompt_size = max([len(t) for t in prompt_tokens]) + + total_len = min(params.max_seq_len, max_gen_len + max_prompt_size) + + tokens = torch.full( + (bsz, total_len), self.tokenizer.pad_id).cuda().long() + for k, t in enumerate(prompt_tokens): + tokens[k, : len(t)] = torch.tensor(t).long() + input_text_mask = tokens != self.tokenizer.pad_id + start_pos = min_prompt_size + prev_pos = 0 + for cur_pos in range(start_pos, total_len): + logits = self.llma.forward_inference(tokens[:, prev_pos:cur_pos], prev_pos, images if prev_pos == 0 else None, modal=modal) + if temperature > 0: + probs = torch.softmax(logits / temperature, dim=-1) + next_token = self.sample_top_p(probs, top_p) + else: + next_token = torch.argmax(logits, dim=-1) + next_token = next_token.reshape(-1) + # only replace token if prompt has already been generated + next_token = torch.where( + input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token + ) + tokens[:, cur_pos] = next_token + prev_pos = cur_pos + + decoded = [] + for i, t in enumerate(tokens.tolist()): + # cut to max gen len + t = t[: len(prompt_tokens[i]) + max_gen_len] + # cut to eos tok if any + try: + t = t[: t.index(self.tokenizer.eos_id)] + except ValueError: + pass + decoded.append(self.tokenizer.decode(t)) + return decoded + + @torch.inference_mode() + def stream_generate( + self, + prompt: str, + images, + max_gen_len: int, + temperature: float = 0.8, + top_p: float = 0.95, + modal = ['image'], + ): + 
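+        # Single-prompt streaming variant of generate(): the prompt is truncated
+        # from the left so that the prompt, the generation budget and (if present)
+        # the image tokens all fit in max_seq_len. Tokens are then decoded one at
+        # a time; after the first step only the newly sampled token is fed through
+        # forward_inference(), relying on the per-layer kv cache, and the partial
+        # decode is yielded with end_of_content=False until EOS or the length
+        # budget is reached.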
params = self.llma.params + + prompt_tokens = self.tokenizer.encode(prompt, bos=True, eos=False) + # truncate from the left. leave some space for generation. + max_seq_len = params.max_seq_len + if images is not None: + max_seq_len -= self.llma.image_words + + max_prompt_size = max_seq_len - max_gen_len + prompt_tokens = prompt_tokens[-max_prompt_size:] + + prompt_size = len(prompt_tokens) + + total_len = min(max_seq_len, max_gen_len + prompt_size) + + tokens = torch.full([total_len], 0).cuda().long() + + tokens[:len(prompt_tokens)] = torch.tensor(prompt_tokens).long() + start_pos = prompt_size + prev_pos = 0 + generate_until = start_pos + for cur_pos in range(start_pos, total_len): + logits = self.llma.forward_inference(tokens[None, prev_pos:cur_pos], prev_pos, images if prev_pos == 0 else None, modal = modal) + if temperature > 0: + probs = torch.softmax(logits / temperature, dim=-1) + next_token = self.sample_top_p(probs, top_p) + else: + next_token = torch.argmax(logits, dim=-1) + next_token = next_token.item() + + if next_token == self.tokenizer.eos_id: + break + + tokens[cur_pos] = next_token + prev_pos = cur_pos + generate_until = cur_pos + 1 + yield {"text": self.tokenizer.decode(tokens[start_pos:generate_until].tolist()), "end_of_content": False} + + yield {"text": self.tokenizer.decode(tokens[start_pos:generate_until].tolist()), "end_of_content": True} + + def sample_top_p(self, probs, p): + probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True) + probs_sum = torch.cumsum(probs_sort, dim=-1) + mask = probs_sum - probs_sort > p + probs_sort[mask] = 0.0 + probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True)) + next_token = torch.multinomial(probs_sort, num_samples=1) + next_token = torch.gather(probs_idx, -1, next_token) + return next_token + + def get_image_words(self): + return self.llma.image_words \ No newline at end of file diff --git a/model/tokenizer.py b/model/tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..e4315856eea5c4318499c8909898252902252f30 --- /dev/null +++ b/model/tokenizer.py @@ -0,0 +1,40 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed according to the terms of the GNU General Public License version 3. 
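+# Thin wrapper around a SentencePiece model: encode() optionally prepends the
+# BOS id and appends the EOS id, decode() maps token ids back to text, and
+# pad_id is taken from the model (SentencePiece reports -1 when no pad piece
+# is defined).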
+ +from sentencepiece import SentencePieceProcessor +from logging import getLogger +from typing import List +import os + + +logger = getLogger() + + +class Tokenizer: + def __init__(self, model_path: str): + # reload tokenizer + assert os.path.isfile(model_path), model_path + self.sp_model = SentencePieceProcessor(model_file=model_path) + logger.info(f"Reloaded SentencePiece model from {model_path}") + + # BOS / EOS token IDs + self.n_words: int = self.sp_model.vocab_size() + self.bos_id: int = self.sp_model.bos_id() + self.eos_id: int = self.sp_model.eos_id() + self.pad_id: int = self.sp_model.pad_id() + logger.info( + f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}" + ) + assert self.sp_model.vocab_size() == self.sp_model.get_piece_size() + + def encode(self, s: str, bos: bool, eos: bool) -> List[int]: + assert type(s) is str + t = self.sp_model.encode(s) + if bos: + t = [self.bos_id] + t + if eos: + t = t + [self.eos_id] + return t + + def decode(self, t: List[int]) -> str: + return self.sp_model.decode(t) diff --git a/requirements.txt b/requirements.txt index ff5a872a773e1619013dc49c7be53ad722943b40..ce74fdd1f2242fc4e7bc50f084f7030081836fa5 100755 --- a/requirements.txt +++ b/requirements.txt @@ -1,8 +1,13 @@ ---extra-index-url https://download.pytorch.org/whl/cu113 -torch==1.12.0+cu113 +--extra-index-url https://download.pytorch.org/whl/cu117 +torch==2.0.0+cu117 +packaging fairscale sentencepiece Pillow huggingface_hub -git+https://github.com/csuhan/timm_0_3_2.git -git+https://github.com/openai/CLIP.git \ No newline at end of file +open_clip_torch +pytorchvideo==0.1.5 +torchaudio +matplotlib +flash-attn +gradio \ No newline at end of file diff --git a/util/__pycache__/misc.cpython-310.pyc b/util/__pycache__/misc.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4caa262729d9934200c8f44f3ca67d0913580474 Binary files /dev/null and b/util/__pycache__/misc.cpython-310.pyc differ diff --git a/util/__pycache__/misc.cpython-39.pyc b/util/__pycache__/misc.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9f464da6e5db2871e7f85f496f2a0df542be9804 Binary files /dev/null and b/util/__pycache__/misc.cpython-39.pyc differ diff --git a/util/lr_sched.py b/util/lr_sched.py new file mode 100644 index 0000000000000000000000000000000000000000..dc4624f4fb441ea7e37e50857813cb149887a0c0 --- /dev/null +++ b/util/lr_sched.py @@ -0,0 +1,42 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
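+# Iteration-level learning-rate schedule: linear warmup to args.lr over
+# warmup_iters, then half-cycle cosine decay towards args.min_lr until
+# lr_decay_iters, and constant min_lr afterwards. For example, with lr=1e-4,
+# min_lr=1e-5, warmup_iters=1000 and lr_decay_iters=10000: it=500 gives 5e-5
+# (warmup), it=5500 gives 5.5e-5 (cosine midpoint), it=20000 gives 1e-5.
+# Param groups carrying an "lr_scale" entry are scaled accordingly.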
+ +import math + +def adjust_learning_rate(optimizer, it, args): + """Decay the learning rate with half-cycle cosine after warmup""" + if it < args.warmup_iters: # 1) linear warmup for warmup_iters steps + lr = args.lr * it / args.warmup_iters + elif it > args.lr_decay_iters: # 2) if it > lr_decay_iters, return min learning rate + lr = args.min_lr + else: # 3) in between, use cosine decay down to min learning rate + decay_ratio = (it - args.warmup_iters) / (args.lr_decay_iters - args.warmup_iters) + assert 0 <= decay_ratio <= 1 + coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1 + lr = args.min_lr + (args.lr - args.min_lr) * coeff + + for param_group in optimizer.param_groups: + if "lr_scale" in param_group: + param_group["lr"] = lr * param_group["lr_scale"] + else: + param_group["lr"] = lr + return lr + + +def adjust_learning_rate_epoch(optimizer, epoch, args): + """Decay the learning rate with half-cycle cosine after warmup""" + if epoch < args.warmup_epochs: + lr = args.lr * epoch / args.warmup_epochs + else: + lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * \ + (1. + math.cos(math.pi * (epoch - args.warmup_epochs) / (args.epochs - args.warmup_epochs))) + for param_group in optimizer.param_groups: + if "lr_scale" in param_group: + param_group["lr"] = lr * param_group["lr_scale"] + else: + param_group["lr"] = lr + return lr + diff --git a/util/misc.py b/util/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..cea0d87e40afd9b8be34ef99da7b1409cb1e43ba --- /dev/null +++ b/util/misc.py @@ -0,0 +1,516 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# -------------------------------------------------------- +# References: +# DeiT: https://github.com/facebookresearch/deit +# BEiT: https://github.com/microsoft/unilm/tree/master/beit +# -------------------------------------------------------- + +import builtins +import datetime +import os +import glob +import time +from collections import defaultdict, deque +from pathlib import Path +import subprocess + +import torch +import torch.distributed as dist +from torch import inf +from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler +from torch.distributed.fsdp import ( + FullyShardedDataParallel as FSDP, + StateDictType, + FullStateDictConfig, +) +from torch.distributed._shard.api import load_with_process_group + +from fairscale.nn.model_parallel import initialize as fs_init + +from types import TracebackType +from typing import Any, Optional +import torch +import torch.nn as nn + +class SmoothedValue(object): + """Track a series of values and provide access to smoothed values over a + window or the global series average. + """ + + def __init__(self, window_size=20, fmt=None): + if fmt is None: + fmt = "{median:.4f} ({global_avg:.4f})" + self.deque = deque(maxlen=window_size) + self.total = 0.0 + self.count = 0 + self.fmt = fmt + + def update(self, value, n=1): + self.deque.append(value) + self.count += n + self.total += value * n + + def synchronize_between_processes(self): + """ + Warning: does not synchronize the deque! 
+ """ + if not is_dist_avail_and_initialized(): + return + t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') + dist.barrier() + dist.all_reduce(t) + t = t.tolist() + self.count = int(t[0]) + self.total = t[1] + + @property + def median(self): + d = torch.tensor(list(self.deque)) + return d.median().item() + + @property + def avg(self): + d = torch.tensor(list(self.deque), dtype=torch.float32) + return d.mean().item() + + @property + def global_avg(self): + return self.total / self.count + + @property + def max(self): + return max(self.deque) + + @property + def value(self): + return self.deque[-1] + + def __str__(self): + return self.fmt.format( + median=self.median, + avg=self.avg, + global_avg=self.global_avg, + max=self.max, + value=self.value) + + +class MetricLogger(object): + def __init__(self, delimiter="\t"): + self.meters = defaultdict(SmoothedValue) + self.delimiter = delimiter + + def update(self, **kwargs): + for k, v in kwargs.items(): + if v is None: + continue + if isinstance(v, torch.Tensor): + v = v.item() + assert isinstance(v, (float, int)) + self.meters[k].update(v) + + def __getattr__(self, attr): + if attr in self.meters: + return self.meters[attr] + if attr in self.__dict__: + return self.__dict__[attr] + raise AttributeError("'{}' object has no attribute '{}'".format( + type(self).__name__, attr)) + + def __str__(self): + loss_str = [] + for name, meter in self.meters.items(): + loss_str.append( + "{}: {}".format(name, str(meter)) + ) + return self.delimiter.join(loss_str) + + def synchronize_between_processes(self): + for meter in self.meters.values(): + meter.synchronize_between_processes() + + def add_meter(self, name, meter): + self.meters[name] = meter + + def log_every(self, iterable, print_freq, header=None, start_iter=0): + i = start_iter + if not header: + header = '' + start_time = time.time() + end = time.time() + iter_time = SmoothedValue(fmt='{avg:.4f}') + data_time = SmoothedValue(fmt='{avg:.4f}') + log_msg = [ + header, + '[{0' + '}/{1}]', + '{meters}', + 'time: {time}', + 'data: {data}' + ] + if torch.cuda.is_available(): + log_msg.append('max mem: {memory:.0f}') + log_msg = self.delimiter.join(log_msg) + MB = 1024.0 * 1024.0 + for obj in iterable: + data_time.update(time.time() - end) + yield obj + iter_time.update(time.time() - end) + if i % print_freq == 0: + try: + total_len = len(iterable) + except: + total_len = "unknown" + if torch.cuda.is_available(): + print(log_msg.format( + i, total_len, + meters=str(self), + time=str(iter_time), data=str(data_time), + memory=torch.cuda.max_memory_allocated() / MB)) + else: + print(log_msg.format( + i, total_len, + meters=str(self), + time=str(iter_time), data=str(data_time))) + i += 1 + end = time.time() + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + print('{} Total time: {} ({:.4f} s / it)'.format( + header, total_time_str, total_time / len(iterable))) + + +def setup_for_distributed(is_master): + """ + This function disables printing when not in master process + """ + builtin_print = builtins.print + + def print(*args, **kwargs): + force = kwargs.pop('force', False) +# force = force or (get_world_size() > 8) + if is_master or force: + now = datetime.datetime.now().time() + builtin_print('[{}] '.format(now), end='') # print with time stamp + builtin_print(*args, **kwargs) + + builtins.print = print + + +def is_dist_avail_and_initialized(): + if not dist.is_available(): + return False + if not dist.is_initialized(): 
+ return False + return True + + +def get_world_size(): + if not is_dist_avail_and_initialized(): + return 1 + return dist.get_world_size() + + +def get_rank(): + if not is_dist_avail_and_initialized(): + return 0 + return dist.get_rank() + + +def is_main_process(): + return get_rank() == 0 + + +def save_on_master(*args, **kwargs): + if is_main_process(): + torch.save(*args, **kwargs) + +def init_distributed_mode(args): + if args.dist_on_itp: + args.rank = int(os.environ['OMPI_COMM_WORLD_RANK']) + args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE']) + args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK']) + args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT']) + os.environ['LOCAL_RANK'] = str(args.gpu) + os.environ['RANK'] = str(args.rank) + os.environ['WORLD_SIZE'] = str(args.world_size) + # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"] + elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: + args.rank = int(os.environ["RANK"]) + args.world_size = int(os.environ['WORLD_SIZE']) + args.gpu = int(os.environ['LOCAL_RANK']) + elif 'SLURM_PROCID' in os.environ: + os.environ['MASTER_PORT'] = '8994' + while 'MASTER_ADDR' not in os.environ or len(os.environ['MASTER_ADDR'].strip()) == 0: + os.environ['MASTER_ADDR'] = subprocess.check_output('sinfo -Nh -n %s | head -n 1 | awk \'{print $1}\'' % os.environ['SLURM_NODELIST'], shell=True, ).decode().strip() + time.sleep(1) + print(os.environ['MASTER_ADDR']) + args.world_size = int(os.environ['SLURM_NPROCS']) + args.rank = int(os.environ['SLURM_PROCID']) + args.gpu = args.rank % torch.cuda.device_count() + args.local_rank = args.gpu + os.environ['LOCAL_RANK'] = str(args.gpu) + os.environ['WORLD_SIZE'] = str(args.world_size) + os.environ['RANK'] = str(args.rank) + else: + print('Not using distributed mode') + setup_for_distributed(is_master=True) # hack + args.distributed = False + return + + args.distributed = True + + torch.cuda.set_device(args.gpu) + args.dist_backend = 'nccl' + print('| distributed init (rank {}): {}, gpu {}'.format( + args.rank, args.dist_url, args.gpu), flush=True) + torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, + world_size=args.world_size, rank=args.rank) + torch.distributed.barrier() + setup_for_distributed(args.rank == 0) + + +def init_distributed_mode1(args): + if args.dist_on_itp: + args.rank = int(os.environ['OMPI_COMM_WORLD_RANK']) + args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE']) + args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK']) + args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT']) + os.environ['LOCAL_RANK'] = str(args.gpu) + os.environ['RANK'] = str(args.rank) + os.environ['WORLD_SIZE'] = str(args.world_size) + # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"] + elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: + args.rank = int(os.environ["RANK"]) + args.world_size = int(os.environ['WORLD_SIZE']) + args.gpu = int(os.environ['LOCAL_RANK']) + elif 'SLURM_PROCID' in os.environ: + args.rank = int(os.environ['SLURM_PROCID']) + args.gpu = args.rank % torch.cuda.device_count() + else: + print('Not using distributed mode') + setup_for_distributed(is_master=True) # hack + args.distributed = False + return + + args.distributed = True + + torch.cuda.set_device(args.gpu) + args.dist_backend = 'nccl' + print('| distributed init (rank {}): {}, gpu {}'.format( + args.rank, args.dist_url, args.gpu), flush=True) + 
torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, + world_size=args.world_size, rank=args.rank) + torch.distributed.barrier() + setup_for_distributed(args.rank == 0) + + +class NativeScalerWithGradNormCount: + state_dict_key = "amp_scaler" + + def __init__(self, args): + self._scaler = ShardedGradScaler(enabled=args.precision in ["fp16"]) + + def __call__(self, loss, optimizer, model, clip_grad=None, parameters=None, create_graph=False, update_grad=True): + if update_grad: + self._scaler.scale(loss).backward(create_graph=create_graph) + if clip_grad is not None: + assert parameters is not None + self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place + # norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad) + norm = model.clip_grad_norm_(clip_grad) + else: + raise NotImplementedError("please set clip_grad to a very large value if you do not want to clip.") + self._scaler.unscale_(optimizer) + norm = get_grad_norm_(parameters) + self._scaler.step(optimizer) + self._scaler.update() + else: + with model.no_sync(): + self._scaler.scale(loss).backward(create_graph=create_graph) + norm = None + return norm + + def state_dict(self): + return self._scaler.state_dict() + + def load_state_dict(self, state_dict): + self._scaler.load_state_dict(state_dict) + + +def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor: + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + parameters = [p for p in parameters if p.grad is not None] + norm_type = float(norm_type) + if len(parameters) == 0: + return torch.tensor(0.) + device = parameters[0].grad.device + if norm_type == inf: + total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters) + else: + total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type) + return total_norm + + +def save_model(output_dir, args, epoch, iteration, model, optimizer, loss_scaler, dataset_state): + save_dir = os.path.join(output_dir, f"epoch_{epoch}_iter_{iteration:09d}") + os.makedirs(save_dir, exist_ok=True) + with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT): + to_save = { + "model": model.state_dict(), + "optimizer": optimizer.state_dict(), + "iter": iteration, + "epoch": epoch, + "scaler": loss_scaler.state_dict(), + "args": args, + "dataset_state": dataset_state, + } + save_path = os.path.join( + save_dir, + f"checkpoint.{dist.get_rank():05d}-of-{dist.get_world_size():05d}.pth", + ) + torch.save(to_save, save_path) + + if args.save_consolidated: + mp_rank = fs_init.get_model_parallel_rank() + mp_world_size = fs_init.get_model_parallel_world_size() + consolidated_model_save_path = os.path.join( + save_dir, + f"consolidated.{mp_rank:02d}-of-{mp_world_size:02d}.pth", + ) + with FSDP.state_dict_type( + model, + StateDictType.FULL_STATE_DICT, + FullStateDictConfig(rank0_only=True, offload_to_cpu=True), + ): + save_dtype = { + "fp16": torch.float16, + "bf16": torch.bfloat16, + "tf32": torch.float32, + }[args.precision] + consolidated_model_state_dict = { + k: v.to(save_dtype) for k, v in model.state_dict().items() + } + if fs_init.get_data_parallel_rank() == 0: + torch.save(consolidated_model_state_dict, consolidated_model_save_path) + + # remove previous ckpts + ckpts = glob.glob(os.path.join(output_dir, "iter_*")) + glob.glob(os.path.join(output_dir, "epoch_*")) + ckpts.sort() + if len(ckpts)>2 and not args.keep_all: + for ckpt in ckpts[:-2]: + print('del', ckpt) 
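+            # only the two most recent checkpoint directories are kept;
+            # everything older is removed here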
+def load_model(args, model, optimizer, loss_scaler):
+    start_iter = 0
+    start_epoch = 0
+    if args.auto_resume:
+        ckpt_dirs = glob.glob(os.path.join(args.output_dir, "iter_*")) + glob.glob(os.path.join(args.output_dir, "epoch_*"))
+        ckpt_dirs.sort()
+        if len(ckpt_dirs) > 0:
+            args.resume = ckpt_dirs[-1]
+    if args.resume:
+        print("Resume checkpoint %s" % args.resume)
+        local_checkpoint_path = os.path.join(
+            args.resume,
+            f"checkpoint.{dist.get_rank():05d}-of-{dist.get_world_size():05d}.pth",
+        )
+        with load_with_process_group(fs_init.get_data_parallel_group()):
+            checkpoint = torch.load(local_checkpoint_path, map_location='cpu')
+        with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
+            model.load_state_dict(checkpoint['model'])
+        optimizer.load_state_dict(checkpoint['optimizer'])
+        loss_scaler.load_state_dict(checkpoint['scaler'])
+        start_iter = int(checkpoint['iter']) + 1
+        if 'epoch' in checkpoint:
+            start_epoch = int(checkpoint['epoch'])
+    return start_epoch, start_iter
+
+def all_reduce_mean(x):
+    world_size = get_world_size()
+    if world_size > 1:
+        if isinstance(x, torch.Tensor):
+            x_reduce = x.clone().cuda()
+        else:
+            x_reduce = torch.tensor(x).cuda()
+        dist.all_reduce(x_reduce)
+        x_reduce /= world_size
+        return x_reduce.item()
+    else:
+        return x
+
+
+def add_weight_decay(model, weight_decay=1e-5, skip_list=()):
+    decay = []
+    no_decay = []
+    for name, param in model.named_parameters():
+        if not param.requires_grad:
+            continue  # frozen weights
+        #if len(param.shape) == 1 or name.endswith(".bias") or name in skip_list:
+        if name.endswith(".bias") or name.endswith("norm.weight"):
+            no_decay.append(param)
+        else:
+            decay.append(param)
+    return [
+        {'params': no_decay, 'weight_decay': 0.},
+        {'params': decay, 'weight_decay': weight_decay}]
+
+
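For context, a minimal sketch (not part of this diff) of how the two parameter groups returned by `add_weight_decay` plug directly into a torch optimizer; the optimizer choice and hyperparameters below are illustrative assumptions, not values from this repository.

```python
import torch

def build_optimizer_sketch(model: torch.nn.Module):
    # biases and *norm.weight tensors land in the zero-decay group, everything else is decayed
    param_groups = add_weight_decay(model, weight_decay=0.1)   # add_weight_decay is defined above
    return torch.optim.AdamW(param_groups, lr=1e-4, betas=(0.9, 0.95))
```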
+
+
+class default_tensor_type:
+    _tensor_type_stack = [(torch.float, "cpu")]
+
+    def __init__(
+        self,
+        dtype: Optional[torch.dtype] = None,
+        device: Optional[str] = None,
+    ) -> None:
+        # Only limited combinations are supported.
+        assert device is None or device in ["cpu", "cuda"]
+        assert dtype is None or dtype in [torch.float, torch.bfloat16, torch.half]
+        self.dtype, self.device = dtype, device
+
+    def __enter__(self) -> None:
+        dtype, device = self.dtype, self.device
+        if dtype is None:
+            dtype = default_tensor_type._tensor_type_stack[-1][0]
+        if device is None:
+            device = default_tensor_type._tensor_type_stack[-1][1]
+        default_tensor_type._tensor_type_stack.append((dtype, device))
+
+        # We use all 3 calls since the newer APIs (set_default_device, set_default_dtype)
+        # seem to be ineffective sometimes (e.g., set_default_device has no effect on
+        # torch.Tensor calls).
+        torch.set_default_tensor_type(default_tensor_type.get_tensor_type(dtype, device))
+        torch.set_default_device(device)
+        torch.set_default_dtype(dtype)
+
+    def __exit__(
+        self,
+        exc_type: Optional[type[BaseException]],
+        exc_val: Optional[BaseException],
+        exc_tb: Optional[TracebackType],
+    ) -> None:
+        default_tensor_type._tensor_type_stack.pop()
+        dtype, device = default_tensor_type._tensor_type_stack[-1]
+
+        torch.set_default_tensor_type(default_tensor_type.get_tensor_type(dtype, device))
+        torch.set_default_device(device)
+        torch.set_default_dtype(dtype)
+
+    @staticmethod
+    def get_tensor_type(dtype: torch.dtype, device: str) -> Any:
+        return {
+            (torch.float, "cpu"): torch.FloatTensor,
+            (torch.bfloat16, "cpu"): torch.BFloat16Tensor,
+            (torch.half, "cpu"): torch.HalfTensor,
+            (torch.float, "cuda"): torch.cuda.FloatTensor,
+            (torch.bfloat16, "cuda"): torch.cuda.BFloat16Tensor,
+            (torch.half, "cuda"): torch.cuda.HalfTensor,
+        }[(dtype, device)]
+
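For context, a minimal sketch (not part of this diff) of how the `default_tensor_type` context manager above nests; it assumes the class is in scope and a CUDA device is available.

```python
import torch

with default_tensor_type(dtype=torch.bfloat16, device="cuda"):
    layer = torch.nn.Linear(4096, 4096)   # parameters allocated directly as bfloat16 on the GPU
    with default_tensor_type(dtype=torch.float):
        stats = torch.zeros(8)            # nested override: float32, device still inherited as "cuda"
assert layer.weight.dtype == torch.bfloat16
# on exit the stack unwinds back to the (torch.float, "cpu") base entry
```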
diff --git a/util/pos_embed.py b/util/pos_embed.py
new file mode 100644
index 0000000000000000000000000000000000000000..1924913c1ffe7c73b889a4d3bad586ee8b3d2d7d
--- /dev/null
+++ b/util/pos_embed.py
@@ -0,0 +1,113 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+# --------------------------------------------------------
+# Position embedding utils
+# --------------------------------------------------------
+
+import numpy as np
+import torch
+
+# --------------------------------------------------------
+# 2D sine-cosine position embedding
+# References:
+# Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
+# MoCo v3: https://github.com/facebookresearch/moco-v3
+# --------------------------------------------------------
+def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
+    """
+    grid_size: int of the grid height and width
+    return:
+    pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
+    """
+    grid_h = np.arange(grid_size, dtype=np.float32)
+    grid_w = np.arange(grid_size, dtype=np.float32)
+    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
+    grid = np.stack(grid, axis=0)
+
+    grid = grid.reshape([2, 1, grid_size, grid_size])
+    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
+    if cls_token:
+        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
+    return pos_embed
+
+
+def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
+    assert embed_dim % 2 == 0
+
+    # use half of dimensions to encode grid_h
+    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
+    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)
+
+    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
+    return emb
+
+
+def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
+    """
+    embed_dim: output dimension for each position
+    pos: a list of positions to be encoded: size (M,)
+    out: (M, D)
+    """
+    assert embed_dim % 2 == 0
+    omega = np.arange(embed_dim // 2, dtype=np.float64)  # np.float was removed in NumPy 1.24+
+    omega /= embed_dim / 2.
+    omega = 1. / 10000**omega  # (D/2,)
+
+    pos = pos.reshape(-1)  # (M,)
+    out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product
+
+    emb_sin = np.sin(out)  # (M, D/2)
+    emb_cos = np.cos(out)  # (M, D/2)
+
+    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
+    return emb
+
+
+# --------------------------------------------------------
+# Interpolate position embeddings for high-resolution
+# References:
+# DeiT: https://github.com/facebookresearch/deit
+# --------------------------------------------------------
+def interpolate_pos_embed(model, checkpoint_model):
+    if 'pos_embed' in checkpoint_model:
+        pos_embed_checkpoint = checkpoint_model['pos_embed']
+        embedding_size = pos_embed_checkpoint.shape[-1]
+        num_patches = model.patch_embed.num_patches
+        num_extra_tokens = model.pos_embed.shape[-2] - num_patches
+        # height (== width) for the checkpoint position embedding
+        orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
+        # height (== width) for the new position embedding
+        new_size = int(num_patches ** 0.5)
+        # class_token and dist_token are kept unchanged
+        if orig_size != new_size:
+            print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
+            extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
+            # only the position tokens are interpolated
+            pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
+            pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
+            pos_tokens = torch.nn.functional.interpolate(
+                pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
+            pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
+            new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
+            checkpoint_model['pos_embed'] = new_pos_embed
+
+
+def interpolate_pos_embed_online(
+    pos_embed, orig_size, new_size, num_extra_tokens: int
+):
+    # pos_embed: e.g. [257, 1024] = [num_extra_tokens + grid points, embed_dim]
+    extra_tokens = pos_embed[:num_extra_tokens]
+    pos_tokens = pos_embed[num_extra_tokens:]
+    embedding_size = pos_tokens.shape[1]
+    pos_tokens = pos_tokens.reshape(
+        -1, orig_size[0], orig_size[1], embedding_size
+    ).permute(0, 3, 1, 2)
+    pos_tokens = torch.nn.functional.interpolate(
+        pos_tokens, size=new_size, mode="bicubic", align_corners=False,
+    )
+    pos_tokens = pos_tokens.permute(0, 2, 3, 1).reshape(-1, embedding_size)
+    new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=0)
+    return new_pos_embed
\ No newline at end of file
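A quick shape check (not part of this diff) for the helpers in `util/pos_embed.py`; the grid sizes and embedding dimension below are illustrative assumptions.

```python
import torch

# assumes get_2d_sincos_pos_embed and interpolate_pos_embed_online from util/pos_embed.py are in scope
pe = get_2d_sincos_pos_embed(embed_dim=768, grid_size=14, cls_token=True)
print(pe.shape)                        # (197, 768) = (1 cls token + 14*14 grid points, embed_dim)

pe = torch.from_numpy(pe).float()      # row 0 is the (all-zero) cls-token embedding
pe_32 = interpolate_pos_embed_online(pe, orig_size=(14, 14), new_size=(32, 32), num_extra_tokens=1)
print(pe_32.shape)                     # torch.Size([1025, 768]) = 1 + 32*32 positions
```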