diff --git a/configuration.py b/configuration.py new file mode 100644 index 0000000000000000000000000000000000000000..aa94cf6a9ac531e400d2f8c8c52b338eb19a91fc --- /dev/null +++ b/configuration.py @@ -0,0 +1,183 @@ +import os + +def Root(): + models_path = "models" #@param {type:"string"} + configs_path = "configs" #@param {type:"string"} + output_path = "output" #@param {type:"string"} + mount_google_drive = False #@param {type:"boolean"} + models_path_gdrive = "/content/drive/MyDrive/AI/models" #@param {type:"string"} + output_path_gdrive = "/content/drive/MyDrive/AI/StableDiffusion" #@param {type:"string"} + + #@markdown **Model Setup** + model_config = "v1-inference.yaml" #@param ["custom","v1-inference.yaml"] + model_checkpoint = "v1-5-pruned-emaonly.ckpt" #@param ["custom","v1-5-pruned.ckpt","v1-5-pruned-emaonly.ckpt","sd-v1-4-full-ema.ckpt","sd-v1-4.ckpt","sd-v1-3-full-ema.ckpt","sd-v1-3.ckpt","sd-v1-2-full-ema.ckpt","sd-v1-2.ckpt","sd-v1-1-full-ema.ckpt","sd-v1-1.ckpt", "robo-diffusion-v1.ckpt","wd-v1-3-float16.ckpt"] + custom_config_path = "" #@param {type:"string"} + custom_checkpoint_path = "" #@param {type:"string"} + half_precision = True + return locals() + + + +def DeforumAnimArgs(): + animation_mode = "3D" #@param ['None', '2D', '3D', 'Video Input', 'Interpolation'] {type:'string'} + max_frames = 200 #@param {type:"number"} + border = 'wrap' #@param ['wrap', 'replicate'] {type:'string'} + + #@markdown ####**Motion Parameters:** + angle = "0:(0)" #@param {type:"string"} + zoom = "0:(1.04)" #@param {type:"string"} + translation_x = "0:(0)" #@param {type:"string"} + translation_y = "0:(0)" #@param {type:"string"} + translation_z = "0:(0)" #@param {type:"string"} + rotation_3d_x = "0:(0)" #@param {type:"string"} + rotation_3d_y = "0:(0)" #@param {type:"string"} + rotation_3d_z = "0:(0)" #@param {type:"string"} + flip_2d_perspective = False #@param {type:"boolean"} + perspective_flip_theta = "0:(0)" #@param {type:"string"} + perspective_flip_phi = "0:(t%15)" #@param {type:"string"} + perspective_flip_gamma = "0:(0)" #@param {type:"string"} + perspective_flip_fv = "0:(0)" #@param {type:"string"} + noise_schedule = "0:(0.02)" #@param {type:"string"} + strength_schedule = "0:(0.65)" #@param {type:"string"} + contrast_schedule = "0:(1.0)" #@param {type:"string"} + + #@markdown ####**Coherence:** + color_coherence = "Match Frame 0 LAB" #@param ['None', 'Match Frame 0 HSV', 'Match Frame 0 LAB', 'Match Frame 0 RGB'] {type:'string'} + diffusion_cadence = "3" #@param ['1','2','3','4','5','6','7','8'] {type:'string'} + + #@markdown #### 3D Depth Warping + use_depth_warping = True #@param {type:"boolean"} + midas_weight = 0.3 #@param {type:"number"} + near_plane = 200 + far_plane = 10000 + fov = 40 #@param {type:"number"} + padding_mode = "border" #@param ['border', 'reflection', 'zeros'] {type:'string'} + sampling_mode = "bicubic" #@param ['bicubic', 'bilinear', 'nearest'] {type:'string'} + save_depth_maps = False #@param {type:"boolean"} + + #@markdown ####**Video Input:** + video_init_path = "./input/video_in.mp4" #@param {type:"string"} + extract_nth_frame = 1 #@param {type:"number"} + overwrite_extracted_frames = True #@param {type:"boolean"} + use_mask_video = False #@param {type:"boolean"} + video_mask_path = "" #@param {type:"string"} + + #@markdown ####**Interpolation:** + interpolate_key_frames = False #@param {type:"boolean"} + interpolate_x_frames = 4 #@param {type:"number"} + + #@markdown ####**Resume Animation:** + resume_from_timestring = False #@param {type:"boolean"} + resume_timestring = "20220829210106" #@param {type:"string"} + return locals() + + + +def DeforumArgs(): + #@markdown **Image Settings** + W = 512 #@param + H = 512 #@param + W, H = map(lambda x: x - x % 64, (W, H)) # resize to integer multiple of 64 + + #@markdown **Sampling Settings** + seed = 2022 #@param + sampler = "klms" #@param ["klms","dpm2","dpm2_ancestral","heun","euler","euler_ancestral","plms", "ddim", "dpm_fast", "dpm_adaptive", "dpmpp_2s_a", "dpmpp_2m"] + steps = 50 #@param + scale = 7 #@param + ddim_eta = 0.0 #@param + dynamic_threshold = None + static_threshold = None + + #@markdown **Save & Display Settings** + save_samples = True #@param {type:"boolean"} + save_settings = True #@param {type:"boolean"} + display_samples = True #@param {type:"boolean"} + save_sample_per_step = False #@param {type:"boolean"} + show_sample_per_step = False #@param {type:"boolean"} + + #@markdown **Prompt Settings** + prompt_weighting = True #@param {type:"boolean"} + normalize_prompt_weights = True #@param {type:"boolean"} + log_weighted_subprompts = False #@param {type:"boolean"} + + #@markdown **Batch Settings** + n_batch = 1 #@param + batch_name = "data" #@param {type:"string"} + filename_format = "{timestring}_{index}_{prompt}.png" #@param ["{timestring}_{index}_{seed}.png","{timestring}_{index}_{prompt}.png"] + seed_behavior = "iter" #@param ["iter","fixed","random"] + make_grid = False #@param {type:"boolean"} + grid_rows = 2 #@param + outdir = "./outputs" + + #@markdown **Init Settings** + use_init = False #@param {type:"boolean"} + strength = 0.0 #@param {type:"number"} + strength_0_no_init = True # Set the strength to 0 automatically when no init image is used + init_image = "" #@param {type:"string"} + # Whiter areas of the mask are areas that change more + use_mask = False #@param {type:"boolean"} + use_alpha_as_mask = False # use the alpha channel of the init image as the mask + mask_file = "" #@param {type:"string"} + invert_mask = False #@param {type:"boolean"} + # Adjust mask image, 1.0 is no adjustment. Should be positive numbers. + mask_brightness_adjust = 1.0 #@param {type:"number"} + mask_contrast_adjust = 1.0 #@param {type:"number"} + + # Overlay the masked image at the end of the generation so it does not get degraded by encoding and decoding + overlay_mask = True # {type:"boolean"} + # Blur edges of final overlay mask, if used. Minimum = 0 (no blur) + mask_overlay_blur = 5 # {type:"number"} + + #@markdown **Exposure/Contrast Conditional Settings** + mean_scale = 0 #@param {type:"number"} + var_scale = 0 #@param {type:"number"} + exposure_scale = 0 #@param {type:"number"} + exposure_target = 0.5 #@param {type:"number"} + + #@markdown **Color Match Conditional Settings** + colormatch_scale = 0 #@param {type:"number"} + colormatch_image = "" #@param {type:"string"} + colormatch_n_colors = 4 #@param {type:"number"} + ignore_sat_weight = 0 #@param {type:"number"} + + #@markdown **CLIP\Aesthetics Conditional Settings** + clip_name = "ViT-L/14" #@param ['ViT-L/14', 'ViT-L/14@336px', 'ViT-B/16', 'ViT-B/32'] + clip_scale = 0 #@param {type:"number"} + aesthetics_scale = 0 #@param {type:"number"} + cutn = 1 #@param {type:"number"} + cut_pow = 0.0001 #@param {type:"number"} + + #@markdown **Other Conditional Settings** + init_mse_scale = 0 #@param {type:"number"} + init_mse_image = "" #@param {type:"string"} + + blue_scale = 1 #@param {type:"number"} + + #@markdown **Conditional Gradient Settings** + gradient_wrt = "x0_pred" #@param ["x", "x0_pred"] + gradient_add_to = "both" #@param ["cond", "uncond", "both"] + decode_method = "linear" #@param ["autoencoder","linear"] + grad_threshold_type = "dynamic" #@param ["dynamic", "static", "mean", "schedule"] + clamp_grad_threshold = 0.2 #@param {type:"number"} + clamp_start = 0.2 #@param + clamp_stop = 0.01 #@param + grad_inject_timing = list(range(1,10)) #@param + + #@markdown **Speed vs VRAM Settings** + cond_uncond_sync = True #@param {type:"boolean"} + + n_samples = 1 # doesnt do anything + precision = 'autocast' + C = 4 + f = 8 + + prompt = "" + timestring = "" + init_latent = None + init_sample = None + init_sample_raw = None + mask_sample = None + init_c = None + + return locals() + diff --git a/deforum-stable-diffusion/Deforum_Stable_Diffusion.ipynb b/deforum-stable-diffusion/Deforum_Stable_Diffusion.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..4cda50d1973555256e54c8c69d7bad0099231778 --- /dev/null +++ b/deforum-stable-diffusion/Deforum_Stable_Diffusion.ipynb @@ -0,0 +1,580 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "ByGXyiHZWM_q" + }, + "source": [ + "# **Deforum Stable Diffusion v0.6**\n", + "[Stable Diffusion](https://github.com/CompVis/stable-diffusion) by Robin Rombach, Andreas Blattmann, Dominik Lorenz, Patrick Esser, Bj\u00f6rn Ommer and the [Stability.ai](https://stability.ai/) Team. [K Diffusion](https://github.com/crowsonkb/k-diffusion) by [Katherine Crowson](https://twitter.com/RiversHaveWings).\n", + "\n", + "[Quick Guide](https://docs.google.com/document/d/1RrQv7FntzOuLg4ohjRZPVL7iptIyBhwwbcEYEW2OfcI/edit?usp=sharing) to Deforum v0.6\n", + "\n", + "Notebook by [deforum](https://discord.gg/upmXXsrwZc)" + ] + }, + { + "cell_type": "code", + "metadata": { + "cellView": "form", + "id": "IJjzzkKlWM_s" + }, + "source": [ + "#@markdown **NVIDIA GPU**\n", + "import subprocess, os, sys\n", + "sub_p_res = subprocess.run(['nvidia-smi', '--query-gpu=name,memory.total,memory.free', '--format=csv,noheader'], stdout=subprocess.PIPE).stdout.decode('utf-8')\n", + "print(f\"{sub_p_res[:-1]}\")" + ], + "outputs": [], + "execution_count": null + }, + { + "cell_type": "markdown", + "metadata": { + "id": "UA8-efH-WM_t" + }, + "source": [ + "# Setup" + ] + }, + { + "cell_type": "code", + "metadata": { + "cellView": "form", + "id": "0D2HQO-PWM_t" + }, + "source": [ + "\n", + "import subprocess, time, gc, os, sys\n", + "\n", + "def setup_environment():\n", + " print_subprocess = False\n", + " use_xformers_for_colab = True\n", + " try:\n", + " ipy = get_ipython()\n", + " except:\n", + " ipy = 'could not get_ipython'\n", + " if 'google.colab' in str(ipy):\n", + " print(\"..setting up environment\")\n", + " start_time = time.time()\n", + " all_process = [\n", + " ['pip', 'install', 'torch==1.12.1+cu113', 'torchvision==0.13.1+cu113', '--extra-index-url', 'https://download.pytorch.org/whl/cu113'],\n", + " ['pip', 'install', 'omegaconf==2.2.3', 'einops==0.4.1', 'pytorch-lightning==1.7.4', 'torchmetrics==0.9.3', 'torchtext==0.13.1', 'transformers==4.21.2', 'kornia==0.6.7'],\n", + " ['git', 'clone', 'https://github.com/deforum-art/deforum-stable-diffusion'],\n", + " ['pip', 'install', 'accelerate', 'ftfy', 'jsonmerge', 'matplotlib', 'resize-right', 'timm', 'torchdiffeq','scikit-learn'],\n", + " ]\n", + " for process in all_process:\n", + " running = subprocess.run(process,stdout=subprocess.PIPE).stdout.decode('utf-8')\n", + " if print_subprocess:\n", + " print(running)\n", + " with open('deforum-stable-diffusion/src/k_diffusion/__init__.py', 'w') as f:\n", + " f.write('')\n", + " sys.path.extend([\n", + " 'deforum-stable-diffusion/',\n", + " 'deforum-stable-diffusion/src',\n", + " ])\n", + " end_time = time.time()\n", + "\n", + " if use_xformers_for_colab:\n", + "\n", + " print(\"..installing xformers\")\n", + "\n", + " all_process = [['pip', 'install', 'triton==2.0.0.dev20220701']]\n", + " for process in all_process:\n", + " running = subprocess.run(process,stdout=subprocess.PIPE).stdout.decode('utf-8')\n", + " if print_subprocess:\n", + " print(running)\n", + " \n", + " v_card_name = subprocess.run(['nvidia-smi', '--query-gpu=name', '--format=csv,noheader'], stdout=subprocess.PIPE).stdout.decode('utf-8')\n", + " if 't4' in v_card_name.lower():\n", + " name_to_download = 'T4'\n", + " elif 'v100' in v_card_name.lower():\n", + " name_to_download = 'V100'\n", + " elif 'a100' in v_card_name.lower():\n", + " name_to_download = 'A100'\n", + " elif 'p100' in v_card_name.lower():\n", + " name_to_download = 'P100'\n", + " else:\n", + " print(v_card_name + ' is currently not supported with xformers flash attention in deforum!')\n", + "\n", + " x_ver = 'xformers-0.0.13.dev0-py3-none-any.whl'\n", + " x_link = 'https://github.com/TheLastBen/fast-stable-diffusion/raw/main/precompiled/' + name_to_download + '/' + x_ver\n", + " \n", + " all_process = [\n", + " ['wget', x_link],\n", + " ['pip', 'install', x_ver],\n", + " ['mv', 'deforum-stable-diffusion/src/ldm/modules/attention.py', 'deforum-stable-diffusion/src/ldm/modules/attention_backup.py'],\n", + " ['mv', 'deforum-stable-diffusion/src/ldm/modules/attention_xformers.py', 'deforum-stable-diffusion/src/ldm/modules/attention.py']\n", + " ]\n", + "\n", + " for process in all_process:\n", + " running = subprocess.run(process,stdout=subprocess.PIPE).stdout.decode('utf-8')\n", + " if print_subprocess:\n", + " print(running)\n", + "\n", + " print(f\"Environment set up in {end_time-start_time:.0f} seconds\")\n", + " else:\n", + " sys.path.extend([\n", + " 'src'\n", + " ])\n", + " return\n", + "\n", + "setup_environment()\n", + "\n", + "import torch\n", + "import random\n", + "import clip\n", + "from IPython import display\n", + "from types import SimpleNamespace\n", + "from helpers.save_images import get_output_folder\n", + "from helpers.settings import load_args\n", + "from helpers.render import render_animation, render_input_video, render_image_batch, render_interpolation\n", + "from helpers.model_load import make_linear_decode, load_model, get_model_output_paths\n", + "from helpers.aesthetics import load_aesthetics_model\n", + "\n", + "#@markdown **Path Setup**\n", + "\n", + "def Root():\n", + " models_path = \"models\" #@param {type:\"string\"}\n", + " configs_path = \"configs\" #@param {type:\"string\"}\n", + " output_path = \"output\" #@param {type:\"string\"}\n", + " mount_google_drive = True #@param {type:\"boolean\"}\n", + " models_path_gdrive = \"/content/drive/MyDrive/AI/models\" #@param {type:\"string\"}\n", + " output_path_gdrive = \"/content/drive/MyDrive/AI/StableDiffusion\" #@param {type:\"string\"}\n", + "\n", + " #@markdown **Model Setup**\n", + " model_config = \"v1-inference.yaml\" #@param [\"custom\",\"v1-inference.yaml\"]\n", + " model_checkpoint = \"v1-5-pruned-emaonly.ckpt\" #@param [\"custom\",\"v1-5-pruned.ckpt\",\"v1-5-pruned-emaonly.ckpt\",\"sd-v1-4-full-ema.ckpt\",\"sd-v1-4.ckpt\",\"sd-v1-3-full-ema.ckpt\",\"sd-v1-3.ckpt\",\"sd-v1-2-full-ema.ckpt\",\"sd-v1-2.ckpt\",\"sd-v1-1-full-ema.ckpt\",\"sd-v1-1.ckpt\", \"robo-diffusion-v1.ckpt\",\"wd-v1-3-float16.ckpt\"]\n", + " custom_config_path = \"\" #@param {type:\"string\"}\n", + " custom_checkpoint_path = \"\" #@param {type:\"string\"}\n", + " half_precision = True\n", + " return locals()\n", + "\n", + "root = Root()\n", + "root = SimpleNamespace(**root)\n", + "\n", + "root.models_path, root.output_path = get_model_output_paths(root)\n", + "root.model, root.device = load_model(root, \n", + " load_on_run_all=True\n", + " , \n", + " check_sha256=True\n", + " )" + ], + "outputs": [], + "execution_count": null + }, + { + "cell_type": "markdown", + "metadata": { + "id": "6JxwhBwtWM_t" + }, + "source": [ + "# Settings" + ] + }, + { + "cell_type": "code", + "metadata": { + "cellView": "form", + "id": "E0tJVYA4WM_u" + }, + "source": [ + "def DeforumAnimArgs():\n", + "\n", + " #@markdown ####**Animation:**\n", + " animation_mode = 'None' #@param ['None', '2D', '3D', 'Video Input', 'Interpolation'] {type:'string'}\n", + " max_frames = 1000 #@param {type:\"number\"}\n", + " border = 'replicate' #@param ['wrap', 'replicate'] {type:'string'}\n", + "\n", + " #@markdown ####**Motion Parameters:**\n", + " angle = \"0:(0)\"#@param {type:\"string\"}\n", + " zoom = \"0:(1.04)\"#@param {type:\"string\"}\n", + " translation_x = \"0:(10*sin(2*3.14*t/10))\"#@param {type:\"string\"}\n", + " translation_y = \"0:(0)\"#@param {type:\"string\"}\n", + " translation_z = \"0:(10)\"#@param {type:\"string\"}\n", + " rotation_3d_x = \"0:(0)\"#@param {type:\"string\"}\n", + " rotation_3d_y = \"0:(0)\"#@param {type:\"string\"}\n", + " rotation_3d_z = \"0:(0)\"#@param {type:\"string\"}\n", + " flip_2d_perspective = False #@param {type:\"boolean\"}\n", + " perspective_flip_theta = \"0:(0)\"#@param {type:\"string\"}\n", + " perspective_flip_phi = \"0:(t%15)\"#@param {type:\"string\"}\n", + " perspective_flip_gamma = \"0:(0)\"#@param {type:\"string\"}\n", + " perspective_flip_fv = \"0:(53)\"#@param {type:\"string\"}\n", + " noise_schedule = \"0: (0.02)\"#@param {type:\"string\"}\n", + " strength_schedule = \"0: (0.65)\"#@param {type:\"string\"}\n", + " contrast_schedule = \"0: (1.0)\"#@param {type:\"string\"}\n", + "\n", + " #@markdown ####**Coherence:**\n", + " color_coherence = 'Match Frame 0 LAB' #@param ['None', 'Match Frame 0 HSV', 'Match Frame 0 LAB', 'Match Frame 0 RGB'] {type:'string'}\n", + " diffusion_cadence = '1' #@param ['1','2','3','4','5','6','7','8'] {type:'string'}\n", + "\n", + " #@markdown ####**3D Depth Warping:**\n", + " use_depth_warping = True #@param {type:\"boolean\"}\n", + " midas_weight = 0.3#@param {type:\"number\"}\n", + " near_plane = 200\n", + " far_plane = 10000\n", + " fov = 40#@param {type:\"number\"}\n", + " padding_mode = 'border'#@param ['border', 'reflection', 'zeros'] {type:'string'}\n", + " sampling_mode = 'bicubic'#@param ['bicubic', 'bilinear', 'nearest'] {type:'string'}\n", + " save_depth_maps = False #@param {type:\"boolean\"}\n", + "\n", + " #@markdown ####**Video Input:**\n", + " video_init_path ='/content/video_in.mp4'#@param {type:\"string\"}\n", + " extract_nth_frame = 1#@param {type:\"number\"}\n", + " overwrite_extracted_frames = True #@param {type:\"boolean\"}\n", + " use_mask_video = False #@param {type:\"boolean\"}\n", + " video_mask_path ='/content/video_in.mp4'#@param {type:\"string\"}\n", + "\n", + " #@markdown ####**Interpolation:**\n", + " interpolate_key_frames = False #@param {type:\"boolean\"}\n", + " interpolate_x_frames = 4 #@param {type:\"number\"}\n", + " \n", + " #@markdown ####**Resume Animation:**\n", + " resume_from_timestring = False #@param {type:\"boolean\"}\n", + " resume_timestring = \"20220829210106\" #@param {type:\"string\"}\n", + "\n", + " return locals()" + ], + "outputs": [], + "execution_count": null + }, + { + "cell_type": "code", + "metadata": { + "id": "i9fly1RIWM_u" + }, + "source": [ + "prompts = [\n", + " \"a beautiful lake by Asher Brown Durand, trending on Artstation\", # the first prompt I want\n", + " \"a beautiful portrait of a woman by Artgerm, trending on Artstation\", # the second prompt I want\n", + " #\"this prompt I don't want it I commented it out\",\n", + " #\"a nousr robot, trending on Artstation\", # use \"nousr robot\" with the robot diffusion model (see model_checkpoint setting)\n", + " #\"touhou 1girl komeiji_koishi portrait, green hair\", # waifu diffusion prompts can use danbooru tag groups (see model_checkpoint)\n", + " #\"this prompt has weights if prompt weighting enabled:2 can also do negative:-2\", # (see prompt_weighting)\n", + "]\n", + "\n", + "animation_prompts = {\n", + " 0: \"a beautiful apple, trending on Artstation\",\n", + " 20: \"a beautiful banana, trending on Artstation\",\n", + " 30: \"a beautiful coconut, trending on Artstation\",\n", + " 40: \"a beautiful durian, trending on Artstation\",\n", + "}" + ], + "outputs": [], + "execution_count": null + }, + { + "cell_type": "code", + "metadata": { + "cellView": "form", + "id": "XVzhbmizWM_u" + }, + "source": [ + "#@markdown **Load Settings**\n", + "override_settings_with_file = False #@param {type:\"boolean\"}\n", + "settings_file = \"custom\" #@param [\"custom\", \"512x512_aesthetic_0.json\",\"512x512_aesthetic_1.json\",\"512x512_colormatch_0.json\",\"512x512_colormatch_1.json\",\"512x512_colormatch_2.json\",\"512x512_colormatch_3.json\"]\n", + "custom_settings_file = \"/content/drive/MyDrive/Settings.txt\"#@param {type:\"string\"}\n", + "\n", + "def DeforumArgs():\n", + " #@markdown **Image Settings**\n", + " W = 512 #@param\n", + " H = 512 #@param\n", + " W, H = map(lambda x: x - x % 64, (W, H)) # resize to integer multiple of 64\n", + "\n", + " #@markdown **Sampling Settings**\n", + " seed = -1 #@param\n", + " sampler = 'dpmpp_2s_a' #@param [\"klms\",\"dpm2\",\"dpm2_ancestral\",\"heun\",\"euler\",\"euler_ancestral\",\"plms\", \"ddim\", \"dpm_fast\", \"dpm_adaptive\", \"dpmpp_2s_a\", \"dpmpp_2m\"]\n", + " steps = 80 #@param\n", + " scale = 7 #@param\n", + " ddim_eta = 0.0 #@param\n", + " dynamic_threshold = None\n", + " static_threshold = None \n", + "\n", + " #@markdown **Save & Display Settings**\n", + " save_samples = True #@param {type:\"boolean\"}\n", + " save_settings = True #@param {type:\"boolean\"}\n", + " display_samples = True #@param {type:\"boolean\"}\n", + " save_sample_per_step = False #@param {type:\"boolean\"}\n", + " show_sample_per_step = False #@param {type:\"boolean\"}\n", + "\n", + " #@markdown **Prompt Settings**\n", + " prompt_weighting = True #@param {type:\"boolean\"}\n", + " normalize_prompt_weights = True #@param {type:\"boolean\"}\n", + " log_weighted_subprompts = False #@param {type:\"boolean\"}\n", + "\n", + " #@markdown **Batch Settings**\n", + " n_batch = 1 #@param\n", + " batch_name = \"StableFun\" #@param {type:\"string\"}\n", + " filename_format = \"{timestring}_{index}_{prompt}.png\" #@param [\"{timestring}_{index}_{seed}.png\",\"{timestring}_{index}_{prompt}.png\"]\n", + " seed_behavior = \"iter\" #@param [\"iter\",\"fixed\",\"random\"]\n", + " make_grid = False #@param {type:\"boolean\"}\n", + " grid_rows = 2 #@param \n", + " outdir = get_output_folder(root.output_path, batch_name)\n", + "\n", + " #@markdown **Init Settings**\n", + " use_init = False #@param {type:\"boolean\"}\n", + " strength = 0.0 #@param {type:\"number\"}\n", + " strength_0_no_init = True # Set the strength to 0 automatically when no init image is used\n", + " init_image = \"https://cdn.pixabay.com/photo/2022/07/30/13/10/green-longhorn-beetle-7353749_1280.jpg\" #@param {type:\"string\"}\n", + " # Whiter areas of the mask are areas that change more\n", + " use_mask = False #@param {type:\"boolean\"}\n", + " use_alpha_as_mask = False # use the alpha channel of the init image as the mask\n", + " mask_file = \"https://www.filterforge.com/wiki/images/archive/b/b7/20080927223728%21Polygonal_gradient_thumb.jpg\" #@param {type:\"string\"}\n", + " invert_mask = False #@param {type:\"boolean\"}\n", + " # Adjust mask image, 1.0 is no adjustment. Should be positive numbers.\n", + " mask_brightness_adjust = 1.0 #@param {type:\"number\"}\n", + " mask_contrast_adjust = 1.0 #@param {type:\"number\"}\n", + " # Overlay the masked image at the end of the generation so it does not get degraded by encoding and decoding\n", + " overlay_mask = True # {type:\"boolean\"}\n", + " # Blur edges of final overlay mask, if used. Minimum = 0 (no blur)\n", + " mask_overlay_blur = 5 # {type:\"number\"}\n", + "\n", + " #@markdown **Exposure/Contrast Conditional Settings**\n", + " mean_scale = 0 #@param {type:\"number\"}\n", + " var_scale = 0 #@param {type:\"number\"}\n", + " exposure_scale = 0 #@param {type:\"number\"}\n", + " exposure_target = 0.5 #@param {type:\"number\"}\n", + "\n", + " #@markdown **Color Match Conditional Settings**\n", + " colormatch_scale = 0 #@param {type:\"number\"}\n", + " colormatch_image = \"https://www.saasdesign.io/wp-content/uploads/2021/02/palette-3-min-980x588.png\" #@param {type:\"string\"}\n", + " colormatch_n_colors = 4 #@param {type:\"number\"}\n", + " ignore_sat_weight = 0 #@param {type:\"number\"}\n", + "\n", + " #@markdown **CLIP\\Aesthetics Conditional Settings**\n", + " clip_name = 'ViT-L/14' #@param ['ViT-L/14', 'ViT-L/14@336px', 'ViT-B/16', 'ViT-B/32']\n", + " clip_scale = 0 #@param {type:\"number\"}\n", + " aesthetics_scale = 0 #@param {type:\"number\"}\n", + " cutn = 1 #@param {type:\"number\"}\n", + " cut_pow = 0.0001 #@param {type:\"number\"}\n", + "\n", + " #@markdown **Other Conditional Settings**\n", + " init_mse_scale = 0 #@param {type:\"number\"}\n", + " init_mse_image = \"https://cdn.pixabay.com/photo/2022/07/30/13/10/green-longhorn-beetle-7353749_1280.jpg\" #@param {type:\"string\"}\n", + "\n", + " blue_scale = 0 #@param {type:\"number\"}\n", + " \n", + " #@markdown **Conditional Gradient Settings**\n", + " gradient_wrt = 'x0_pred' #@param [\"x\", \"x0_pred\"]\n", + " gradient_add_to = 'both' #@param [\"cond\", \"uncond\", \"both\"]\n", + " decode_method = 'linear' #@param [\"autoencoder\",\"linear\"]\n", + " grad_threshold_type = 'dynamic' #@param [\"dynamic\", \"static\", \"mean\", \"schedule\"]\n", + " clamp_grad_threshold = 0.2 #@param {type:\"number\"}\n", + " clamp_start = 0.2 #@param\n", + " clamp_stop = 0.01 #@param\n", + " grad_inject_timing = list(range(1,10)) #@param\n", + "\n", + " #@markdown **Speed vs VRAM Settings**\n", + " cond_uncond_sync = True #@param {type:\"boolean\"}\n", + "\n", + " n_samples = 1 # doesnt do anything\n", + " precision = 'autocast' \n", + " C = 4\n", + " f = 8\n", + "\n", + " prompt = \"\"\n", + " timestring = \"\"\n", + " init_latent = None\n", + " init_sample = None\n", + " init_sample_raw = None\n", + " mask_sample = None\n", + " init_c = None\n", + "\n", + " return locals()\n", + "\n", + "args_dict = DeforumArgs()\n", + "anim_args_dict = DeforumAnimArgs()\n", + "\n", + "if override_settings_with_file:\n", + " load_args(args_dict, anim_args_dict, settings_file, custom_settings_file, verbose=False)\n", + "\n", + "args = SimpleNamespace(**args_dict)\n", + "anim_args = SimpleNamespace(**anim_args_dict)\n", + "\n", + "args.timestring = time.strftime('%Y%m%d%H%M%S')\n", + "args.strength = max(0.0, min(1.0, args.strength))\n", + "\n", + "# Load clip model if using clip guidance\n", + "if (args.clip_scale > 0) or (args.aesthetics_scale > 0):\n", + " root.clip_model = clip.load(args.clip_name, jit=False)[0].eval().requires_grad_(False).to(root.device)\n", + " if (args.aesthetics_scale > 0):\n", + " root.aesthetics_model = load_aesthetics_model(args, root)\n", + "\n", + "if args.seed == -1:\n", + " args.seed = random.randint(0, 2**32 - 1)\n", + "if not args.use_init:\n", + " args.init_image = None\n", + "if args.sampler == 'plms' and (args.use_init or anim_args.animation_mode != 'None'):\n", + " print(f\"Init images aren't supported with PLMS yet, switching to KLMS\")\n", + " args.sampler = 'klms'\n", + "if args.sampler != 'ddim':\n", + " args.ddim_eta = 0\n", + "\n", + "if anim_args.animation_mode == 'None':\n", + " anim_args.max_frames = 1\n", + "elif anim_args.animation_mode == 'Video Input':\n", + " args.use_init = True\n", + "\n", + "# clean up unused memory\n", + "gc.collect()\n", + "torch.cuda.empty_cache()\n", + "\n", + "# dispatch to appropriate renderer\n", + "if anim_args.animation_mode == '2D' or anim_args.animation_mode == '3D':\n", + " render_animation(args, anim_args, animation_prompts, root)\n", + "elif anim_args.animation_mode == 'Video Input':\n", + " render_input_video(args, anim_args, animation_prompts, root)\n", + "elif anim_args.animation_mode == 'Interpolation':\n", + " render_interpolation(args, anim_args, animation_prompts, root)\n", + "else:\n", + " render_image_batch(args, prompts, root)" + ], + "outputs": [], + "execution_count": null + }, + { + "cell_type": "markdown", + "metadata": { + "id": "gJ88kZ2-WM_v" + }, + "source": [ + "# Create Video From Frames" + ] + }, + { + "cell_type": "code", + "metadata": { + "cellView": "form", + "id": "XQGeqaGAWM_v" + }, + "source": [ + "skip_video_for_run_all = True #@param {type: 'boolean'}\n", + "fps = 12 #@param {type:\"number\"}\n", + "#@markdown **Manual Settings**\n", + "use_manual_settings = False #@param {type:\"boolean\"}\n", + "image_path = \"/content/drive/MyDrive/AI/StableDiffusion/2022-09/20220903000939_%05d.png\" #@param {type:\"string\"}\n", + "mp4_path = \"/content/drive/MyDrive/AI/StableDiffusion/2022-09/20220903000939.mp4\" #@param {type:\"string\"}\n", + "render_steps = False #@param {type: 'boolean'}\n", + "path_name_modifier = \"x0_pred\" #@param [\"x0_pred\",\"x\"]\n", + "make_gif = False\n", + "\n", + "if skip_video_for_run_all == True:\n", + " print('Skipping video creation, uncheck skip_video_for_run_all if you want to run it')\n", + "else:\n", + " import os\n", + " import subprocess\n", + " from base64 import b64encode\n", + "\n", + " print(f\"{image_path} -> {mp4_path}\")\n", + "\n", + " if use_manual_settings:\n", + " max_frames = \"200\" #@param {type:\"string\"}\n", + " else:\n", + " if render_steps: # render steps from a single image\n", + " fname = f\"{path_name_modifier}_%05d.png\"\n", + " all_step_dirs = [os.path.join(args.outdir, d) for d in os.listdir(args.outdir) if os.path.isdir(os.path.join(args.outdir,d))]\n", + " newest_dir = max(all_step_dirs, key=os.path.getmtime)\n", + " image_path = os.path.join(newest_dir, fname)\n", + " print(f\"Reading images from {image_path}\")\n", + " mp4_path = os.path.join(newest_dir, f\"{args.timestring}_{path_name_modifier}.mp4\")\n", + " max_frames = str(args.steps)\n", + " else: # render images for a video\n", + " image_path = os.path.join(args.outdir, f\"{args.timestring}_%05d.png\")\n", + " mp4_path = os.path.join(args.outdir, f\"{args.timestring}.mp4\")\n", + " max_frames = str(anim_args.max_frames)\n", + "\n", + " # make video\n", + " cmd = [\n", + " 'ffmpeg',\n", + " '-y',\n", + " '-vcodec', 'png',\n", + " '-r', str(fps),\n", + " '-start_number', str(0),\n", + " '-i', image_path,\n", + " '-frames:v', max_frames,\n", + " '-c:v', 'libx264',\n", + " '-vf',\n", + " f'fps={fps}',\n", + " '-pix_fmt', 'yuv420p',\n", + " '-crf', '17',\n", + " '-preset', 'veryfast',\n", + " '-pattern_type', 'sequence',\n", + " mp4_path\n", + " ]\n", + " process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)\n", + " stdout, stderr = process.communicate()\n", + " if process.returncode != 0:\n", + " print(stderr)\n", + " raise RuntimeError(stderr)\n", + "\n", + " mp4 = open(mp4_path,'rb').read()\n", + " data_url = \"data:video/mp4;base64,\" + b64encode(mp4).decode()\n", + " display.display(display.HTML(f'') )\n", + " \n", + " if make_gif:\n", + " gif_path = os.path.splitext(mp4_path)[0]+'.gif'\n", + " cmd_gif = [\n", + " 'ffmpeg',\n", + " '-y',\n", + " '-i', mp4_path,\n", + " '-r', str(fps),\n", + " gif_path\n", + " ]\n", + " process_gif = subprocess.Popen(cmd_gif, stdout=subprocess.PIPE, stderr=subprocess.PIPE)" + ], + "outputs": [], + "execution_count": null + }, + { + "cell_type": "code", + "metadata": { + "cellView": "form", + "id": "MMpAcyrYWM_v" + }, + "source": [ + "skip_disconnect_for_run_all = True #@param {type: 'boolean'}\n", + "\n", + "if skip_disconnect_for_run_all == True:\n", + " print('Skipping disconnect, uncheck skip_disconnect_for_run_all if you want to run it')\n", + "else:\n", + " from google.colab import runtime\n", + " runtime.unassign()" + ], + "outputs": [], + "execution_count": null + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3.10.6 ('dsd')", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.6" + }, + "orig_nbformat": 4, + "vscode": { + "interpreter": { + "hash": "b7e04c8a9537645cbc77fa0cbde8069bc94e341b0d5ced104651213865b24e58" + } + }, + "colab": { + "provenance": [] + }, + "accelerator": "GPU", + "gpuClass": "standard" + }, + "nbformat": 4, + "nbformat_minor": 4 +} \ No newline at end of file diff --git a/deforum-stable-diffusion/Deforum_Stable_Diffusion.py b/deforum-stable-diffusion/Deforum_Stable_Diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..a7df9fbb64de9633eb3b85bca7c8f5950d23fc56 --- /dev/null +++ b/deforum-stable-diffusion/Deforum_Stable_Diffusion.py @@ -0,0 +1,536 @@ +# %% +# !! {"metadata":{ +# !! "id": "ByGXyiHZWM_q" +# !! }} +""" +# **Deforum Stable Diffusion v0.6** +[Stable Diffusion](https://github.com/CompVis/stable-diffusion) by Robin Rombach, Andreas Blattmann, Dominik Lorenz, Patrick Esser, Björn Ommer and the [Stability.ai](https://stability.ai/) Team. [K Diffusion](https://github.com/crowsonkb/k-diffusion) by [Katherine Crowson](https://twitter.com/RiversHaveWings). + +[Quick Guide](https://docs.google.com/document/d/1RrQv7FntzOuLg4ohjRZPVL7iptIyBhwwbcEYEW2OfcI/edit?usp=sharing) to Deforum v0.6 + +Notebook by [deforum](https://discord.gg/upmXXsrwZc) +""" + +# %% +# !! {"metadata":{ +# !! "cellView": "form", +# !! "id": "IJjzzkKlWM_s" +# !! }} +#@markdown **NVIDIA GPU** +import subprocess, os, sys +sub_p_res = subprocess.run(['nvidia-smi', '--query-gpu=name,memory.total,memory.free', '--format=csv,noheader'], stdout=subprocess.PIPE).stdout.decode('utf-8') +print(f"{sub_p_res[:-1]}") + +# %% +# !! {"metadata":{ +# !! "id": "UA8-efH-WM_t" +# !! }} +""" +# Setup +""" + +# %% +# !! {"metadata":{ +# !! "cellView": "form", +# !! "id": "0D2HQO-PWM_t" +# !! }} + +import subprocess, time, gc, os, sys + +def setup_environment(): + print_subprocess = False + use_xformers_for_colab = True + try: + ipy = get_ipython() + except: + ipy = 'could not get_ipython' + if 'google.colab' in str(ipy): + print("..setting up environment") + start_time = time.time() + all_process = [ + ['pip', 'install', 'torch==1.12.1+cu113', 'torchvision==0.13.1+cu113', '--extra-index-url', 'https://download.pytorch.org/whl/cu113'], + ['pip', 'install', 'omegaconf==2.2.3', 'einops==0.4.1', 'pytorch-lightning==1.7.4', 'torchmetrics==0.9.3', 'torchtext==0.13.1', 'transformers==4.21.2', 'kornia==0.6.7'], + ['git', 'clone', 'https://github.com/deforum-art/deforum-stable-diffusion'], + ['pip', 'install', 'accelerate', 'ftfy', 'jsonmerge', 'matplotlib', 'resize-right', 'timm', 'torchdiffeq','scikit-learn'], + ] + for process in all_process: + running = subprocess.run(process,stdout=subprocess.PIPE).stdout.decode('utf-8') + if print_subprocess: + print(running) + with open('deforum-stable-diffusion/src/k_diffusion/__init__.py', 'w') as f: + f.write('') + sys.path.extend([ + 'deforum-stable-diffusion/', + 'deforum-stable-diffusion/src', + ]) + end_time = time.time() + + if use_xformers_for_colab: + + print("..installing xformers") + + all_process = [['pip', 'install', 'triton==2.0.0.dev20220701']] + for process in all_process: + running = subprocess.run(process,stdout=subprocess.PIPE).stdout.decode('utf-8') + if print_subprocess: + print(running) + + v_card_name = subprocess.run(['nvidia-smi', '--query-gpu=name', '--format=csv,noheader'], stdout=subprocess.PIPE).stdout.decode('utf-8') + if 't4' in v_card_name.lower(): + name_to_download = 'T4' + elif 'v100' in v_card_name.lower(): + name_to_download = 'V100' + elif 'a100' in v_card_name.lower(): + name_to_download = 'A100' + elif 'p100' in v_card_name.lower(): + name_to_download = 'P100' + else: + print(v_card_name + ' is currently not supported with xformers flash attention in deforum!') + + x_ver = 'xformers-0.0.13.dev0-py3-none-any.whl' + x_link = 'https://github.com/TheLastBen/fast-stable-diffusion/raw/main/precompiled/' + name_to_download + '/' + x_ver + + all_process = [ + ['wget', x_link], + ['pip', 'install', x_ver], + ['mv', 'deforum-stable-diffusion/src/ldm/modules/attention.py', 'deforum-stable-diffusion/src/ldm/modules/attention_backup.py'], + ['mv', 'deforum-stable-diffusion/src/ldm/modules/attention_xformers.py', 'deforum-stable-diffusion/src/ldm/modules/attention.py'] + ] + + for process in all_process: + running = subprocess.run(process,stdout=subprocess.PIPE).stdout.decode('utf-8') + if print_subprocess: + print(running) + + print(f"Environment set up in {end_time-start_time:.0f} seconds") + else: + sys.path.extend([ + 'src' + ]) + return + +setup_environment() + +import torch +import random +import clip +from IPython import display +from types import SimpleNamespace +from helpers.save_images import get_output_folder +from helpers.settings import load_args +from helpers.render import render_animation, render_input_video, render_image_batch, render_interpolation +from helpers.model_load import make_linear_decode, load_model, get_model_output_paths +from helpers.aesthetics import load_aesthetics_model + +#@markdown **Path Setup** + +def Root(): + models_path = "models" #@param {type:"string"} + configs_path = "configs" #@param {type:"string"} + output_path = "output" #@param {type:"string"} + mount_google_drive = True #@param {type:"boolean"} + models_path_gdrive = "/content/drive/MyDrive/AI/models" #@param {type:"string"} + output_path_gdrive = "/content/drive/MyDrive/AI/StableDiffusion" #@param {type:"string"} + + #@markdown **Model Setup** + model_config = "v1-inference.yaml" #@param ["custom","v1-inference.yaml"] + model_checkpoint = "v1-5-pruned-emaonly.ckpt" #@param ["custom","v1-5-pruned.ckpt","v1-5-pruned-emaonly.ckpt","sd-v1-4-full-ema.ckpt","sd-v1-4.ckpt","sd-v1-3-full-ema.ckpt","sd-v1-3.ckpt","sd-v1-2-full-ema.ckpt","sd-v1-2.ckpt","sd-v1-1-full-ema.ckpt","sd-v1-1.ckpt", "robo-diffusion-v1.ckpt","wd-v1-3-float16.ckpt"] + custom_config_path = "" #@param {type:"string"} + custom_checkpoint_path = "" #@param {type:"string"} + half_precision = True + return locals() + +root = Root() +root = SimpleNamespace(**root) + +root.models_path, root.output_path = get_model_output_paths(root) +root.model, root.device = load_model(root, + load_on_run_all=True + , + check_sha256=True + ) + +# %% +# !! {"metadata":{ +# !! "id": "6JxwhBwtWM_t" +# !! }} +""" +# Settings +""" + +# %% +# !! {"metadata":{ +# !! "cellView": "form", +# !! "id": "E0tJVYA4WM_u" +# !! }} +def DeforumAnimArgs(): + + #@markdown ####**Animation:** + animation_mode = 'None' #@param ['None', '2D', '3D', 'Video Input', 'Interpolation'] {type:'string'} + max_frames = 1000 #@param {type:"number"} + border = 'replicate' #@param ['wrap', 'replicate'] {type:'string'} + + #@markdown ####**Motion Parameters:** + angle = "0:(0)"#@param {type:"string"} + zoom = "0:(1.04)"#@param {type:"string"} + translation_x = "0:(10*sin(2*3.14*t/10))"#@param {type:"string"} + translation_y = "0:(0)"#@param {type:"string"} + translation_z = "0:(10)"#@param {type:"string"} + rotation_3d_x = "0:(0)"#@param {type:"string"} + rotation_3d_y = "0:(0)"#@param {type:"string"} + rotation_3d_z = "0:(0)"#@param {type:"string"} + flip_2d_perspective = False #@param {type:"boolean"} + perspective_flip_theta = "0:(0)"#@param {type:"string"} + perspective_flip_phi = "0:(t%15)"#@param {type:"string"} + perspective_flip_gamma = "0:(0)"#@param {type:"string"} + perspective_flip_fv = "0:(53)"#@param {type:"string"} + noise_schedule = "0: (0.02)"#@param {type:"string"} + strength_schedule = "0: (0.65)"#@param {type:"string"} + contrast_schedule = "0: (1.0)"#@param {type:"string"} + + #@markdown ####**Coherence:** + color_coherence = 'Match Frame 0 LAB' #@param ['None', 'Match Frame 0 HSV', 'Match Frame 0 LAB', 'Match Frame 0 RGB'] {type:'string'} + diffusion_cadence = '1' #@param ['1','2','3','4','5','6','7','8'] {type:'string'} + + #@markdown ####**3D Depth Warping:** + use_depth_warping = True #@param {type:"boolean"} + midas_weight = 0.3#@param {type:"number"} + near_plane = 200 + far_plane = 10000 + fov = 40#@param {type:"number"} + padding_mode = 'border'#@param ['border', 'reflection', 'zeros'] {type:'string'} + sampling_mode = 'bicubic'#@param ['bicubic', 'bilinear', 'nearest'] {type:'string'} + save_depth_maps = False #@param {type:"boolean"} + + #@markdown ####**Video Input:** + video_init_path ='/content/video_in.mp4'#@param {type:"string"} + extract_nth_frame = 1#@param {type:"number"} + overwrite_extracted_frames = True #@param {type:"boolean"} + use_mask_video = False #@param {type:"boolean"} + video_mask_path ='/content/video_in.mp4'#@param {type:"string"} + + #@markdown ####**Interpolation:** + interpolate_key_frames = False #@param {type:"boolean"} + interpolate_x_frames = 4 #@param {type:"number"} + + #@markdown ####**Resume Animation:** + resume_from_timestring = False #@param {type:"boolean"} + resume_timestring = "20220829210106" #@param {type:"string"} + + return locals() + +# %% +# !! {"metadata":{ +# !! "id": "i9fly1RIWM_u" +# !! }} +prompts = [ + "a beautiful lake by Asher Brown Durand, trending on Artstation", # the first prompt I want + "a beautiful portrait of a woman by Artgerm, trending on Artstation", # the second prompt I want + #"this prompt I don't want it I commented it out", + #"a nousr robot, trending on Artstation", # use "nousr robot" with the robot diffusion model (see model_checkpoint setting) + #"touhou 1girl komeiji_koishi portrait, green hair", # waifu diffusion prompts can use danbooru tag groups (see model_checkpoint) + #"this prompt has weights if prompt weighting enabled:2 can also do negative:-2", # (see prompt_weighting) +] + +animation_prompts = { + 0: "a beautiful apple, trending on Artstation", + 20: "a beautiful banana, trending on Artstation", + 30: "a beautiful coconut, trending on Artstation", + 40: "a beautiful durian, trending on Artstation", +} + +# %% +# !! {"metadata":{ +# !! "cellView": "form", +# !! "id": "XVzhbmizWM_u" +# !! }} +#@markdown **Load Settings** +override_settings_with_file = False #@param {type:"boolean"} +settings_file = "custom" #@param ["custom", "512x512_aesthetic_0.json","512x512_aesthetic_1.json","512x512_colormatch_0.json","512x512_colormatch_1.json","512x512_colormatch_2.json","512x512_colormatch_3.json"] +custom_settings_file = "/content/drive/MyDrive/Settings.txt"#@param {type:"string"} + +def DeforumArgs(): + #@markdown **Image Settings** + W = 512 #@param + H = 512 #@param + W, H = map(lambda x: x - x % 64, (W, H)) # resize to integer multiple of 64 + + #@markdown **Sampling Settings** + seed = -1 #@param + sampler = 'dpmpp_2s_a' #@param ["klms","dpm2","dpm2_ancestral","heun","euler","euler_ancestral","plms", "ddim", "dpm_fast", "dpm_adaptive", "dpmpp_2s_a", "dpmpp_2m"] + steps = 80 #@param + scale = 7 #@param + ddim_eta = 0.0 #@param + dynamic_threshold = None + static_threshold = None + + #@markdown **Save & Display Settings** + save_samples = True #@param {type:"boolean"} + save_settings = True #@param {type:"boolean"} + display_samples = True #@param {type:"boolean"} + save_sample_per_step = False #@param {type:"boolean"} + show_sample_per_step = False #@param {type:"boolean"} + + #@markdown **Prompt Settings** + prompt_weighting = True #@param {type:"boolean"} + normalize_prompt_weights = True #@param {type:"boolean"} + log_weighted_subprompts = False #@param {type:"boolean"} + + #@markdown **Batch Settings** + n_batch = 1 #@param + batch_name = "StableFun" #@param {type:"string"} + filename_format = "{timestring}_{index}_{prompt}.png" #@param ["{timestring}_{index}_{seed}.png","{timestring}_{index}_{prompt}.png"] + seed_behavior = "iter" #@param ["iter","fixed","random"] + make_grid = False #@param {type:"boolean"} + grid_rows = 2 #@param + outdir = get_output_folder(root.output_path, batch_name) + + #@markdown **Init Settings** + use_init = False #@param {type:"boolean"} + strength = 0.0 #@param {type:"number"} + strength_0_no_init = True # Set the strength to 0 automatically when no init image is used + init_image = "https://cdn.pixabay.com/photo/2022/07/30/13/10/green-longhorn-beetle-7353749_1280.jpg" #@param {type:"string"} + # Whiter areas of the mask are areas that change more + use_mask = False #@param {type:"boolean"} + use_alpha_as_mask = False # use the alpha channel of the init image as the mask + mask_file = "https://www.filterforge.com/wiki/images/archive/b/b7/20080927223728%21Polygonal_gradient_thumb.jpg" #@param {type:"string"} + invert_mask = False #@param {type:"boolean"} + # Adjust mask image, 1.0 is no adjustment. Should be positive numbers. + mask_brightness_adjust = 1.0 #@param {type:"number"} + mask_contrast_adjust = 1.0 #@param {type:"number"} + # Overlay the masked image at the end of the generation so it does not get degraded by encoding and decoding + overlay_mask = True # {type:"boolean"} + # Blur edges of final overlay mask, if used. Minimum = 0 (no blur) + mask_overlay_blur = 5 # {type:"number"} + + #@markdown **Exposure/Contrast Conditional Settings** + mean_scale = 0 #@param {type:"number"} + var_scale = 0 #@param {type:"number"} + exposure_scale = 0 #@param {type:"number"} + exposure_target = 0.5 #@param {type:"number"} + + #@markdown **Color Match Conditional Settings** + colormatch_scale = 0 #@param {type:"number"} + colormatch_image = "https://www.saasdesign.io/wp-content/uploads/2021/02/palette-3-min-980x588.png" #@param {type:"string"} + colormatch_n_colors = 4 #@param {type:"number"} + ignore_sat_weight = 0 #@param {type:"number"} + + #@markdown **CLIP\Aesthetics Conditional Settings** + clip_name = 'ViT-L/14' #@param ['ViT-L/14', 'ViT-L/14@336px', 'ViT-B/16', 'ViT-B/32'] + clip_scale = 0 #@param {type:"number"} + aesthetics_scale = 0 #@param {type:"number"} + cutn = 1 #@param {type:"number"} + cut_pow = 0.0001 #@param {type:"number"} + + #@markdown **Other Conditional Settings** + init_mse_scale = 0 #@param {type:"number"} + init_mse_image = "https://cdn.pixabay.com/photo/2022/07/30/13/10/green-longhorn-beetle-7353749_1280.jpg" #@param {type:"string"} + + blue_scale = 0 #@param {type:"number"} + + #@markdown **Conditional Gradient Settings** + gradient_wrt = 'x0_pred' #@param ["x", "x0_pred"] + gradient_add_to = 'both' #@param ["cond", "uncond", "both"] + decode_method = 'linear' #@param ["autoencoder","linear"] + grad_threshold_type = 'dynamic' #@param ["dynamic", "static", "mean", "schedule"] + clamp_grad_threshold = 0.2 #@param {type:"number"} + clamp_start = 0.2 #@param + clamp_stop = 0.01 #@param + grad_inject_timing = list(range(1,10)) #@param + + #@markdown **Speed vs VRAM Settings** + cond_uncond_sync = True #@param {type:"boolean"} + + n_samples = 1 # doesnt do anything + precision = 'autocast' + C = 4 + f = 8 + + prompt = "" + timestring = "" + init_latent = None + init_sample = None + init_sample_raw = None + mask_sample = None + init_c = None + + return locals() + +args_dict = DeforumArgs() +anim_args_dict = DeforumAnimArgs() + +if override_settings_with_file: + load_args(args_dict, anim_args_dict, settings_file, custom_settings_file, verbose=False) + +args = SimpleNamespace(**args_dict) +anim_args = SimpleNamespace(**anim_args_dict) + +args.timestring = time.strftime('%Y%m%d%H%M%S') +args.strength = max(0.0, min(1.0, args.strength)) + +# Load clip model if using clip guidance +if (args.clip_scale > 0) or (args.aesthetics_scale > 0): + root.clip_model = clip.load(args.clip_name, jit=False)[0].eval().requires_grad_(False).to(root.device) + if (args.aesthetics_scale > 0): + root.aesthetics_model = load_aesthetics_model(args, root) + +if args.seed == -1: + args.seed = random.randint(0, 2**32 - 1) +if not args.use_init: + args.init_image = None +if args.sampler == 'plms' and (args.use_init or anim_args.animation_mode != 'None'): + print(f"Init images aren't supported with PLMS yet, switching to KLMS") + args.sampler = 'klms' +if args.sampler != 'ddim': + args.ddim_eta = 0 + +if anim_args.animation_mode == 'None': + anim_args.max_frames = 1 +elif anim_args.animation_mode == 'Video Input': + args.use_init = True + +# clean up unused memory +gc.collect() +torch.cuda.empty_cache() + +# dispatch to appropriate renderer +if anim_args.animation_mode == '2D' or anim_args.animation_mode == '3D': + render_animation(args, anim_args, animation_prompts, root) +elif anim_args.animation_mode == 'Video Input': + render_input_video(args, anim_args, animation_prompts, root) +elif anim_args.animation_mode == 'Interpolation': + render_interpolation(args, anim_args, animation_prompts, root) +else: + render_image_batch(args, prompts, root) + +# %% +# !! {"metadata":{ +# !! "id": "gJ88kZ2-WM_v" +# !! }} +""" +# Create Video From Frames +""" + +# %% +# !! {"metadata":{ +# !! "cellView": "form", +# !! "id": "XQGeqaGAWM_v" +# !! }} +skip_video_for_run_all = True #@param {type: 'boolean'} +fps = 12 #@param {type:"number"} +#@markdown **Manual Settings** +use_manual_settings = False #@param {type:"boolean"} +image_path = "/content/drive/MyDrive/AI/StableDiffusion/2022-09/20220903000939_%05d.png" #@param {type:"string"} +mp4_path = "/content/drive/MyDrive/AI/StableDiffusion/2022-09/20220903000939.mp4" #@param {type:"string"} +render_steps = False #@param {type: 'boolean'} +path_name_modifier = "x0_pred" #@param ["x0_pred","x"] +make_gif = False + +if skip_video_for_run_all == True: + print('Skipping video creation, uncheck skip_video_for_run_all if you want to run it') +else: + import os + import subprocess + from base64 import b64encode + + print(f"{image_path} -> {mp4_path}") + + if use_manual_settings: + max_frames = "200" #@param {type:"string"} + else: + if render_steps: # render steps from a single image + fname = f"{path_name_modifier}_%05d.png" + all_step_dirs = [os.path.join(args.outdir, d) for d in os.listdir(args.outdir) if os.path.isdir(os.path.join(args.outdir,d))] + newest_dir = max(all_step_dirs, key=os.path.getmtime) + image_path = os.path.join(newest_dir, fname) + print(f"Reading images from {image_path}") + mp4_path = os.path.join(newest_dir, f"{args.timestring}_{path_name_modifier}.mp4") + max_frames = str(args.steps) + else: # render images for a video + image_path = os.path.join(args.outdir, f"{args.timestring}_%05d.png") + mp4_path = os.path.join(args.outdir, f"{args.timestring}.mp4") + max_frames = str(anim_args.max_frames) + + # make video + cmd = [ + 'ffmpeg', + '-y', + '-vcodec', 'png', + '-r', str(fps), + '-start_number', str(0), + '-i', image_path, + '-frames:v', max_frames, + '-c:v', 'libx264', + '-vf', + f'fps={fps}', + '-pix_fmt', 'yuv420p', + '-crf', '17', + '-preset', 'veryfast', + '-pattern_type', 'sequence', + mp4_path + ] + process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + stdout, stderr = process.communicate() + if process.returncode != 0: + print(stderr) + raise RuntimeError(stderr) + + mp4 = open(mp4_path,'rb').read() + data_url = "data:video/mp4;base64," + b64encode(mp4).decode() + display.display(display.HTML(f'') ) + + if make_gif: + gif_path = os.path.splitext(mp4_path)[0]+'.gif' + cmd_gif = [ + 'ffmpeg', + '-y', + '-i', mp4_path, + '-r', str(fps), + gif_path + ] + process_gif = subprocess.Popen(cmd_gif, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + +# %% +# !! {"metadata":{ +# !! "cellView": "form", +# !! "id": "MMpAcyrYWM_v" +# !! }} +skip_disconnect_for_run_all = True #@param {type: 'boolean'} + +if skip_disconnect_for_run_all == True: + print('Skipping disconnect, uncheck skip_disconnect_for_run_all if you want to run it') +else: + from google.colab import runtime + runtime.unassign() + +# %% +# !! {"main_metadata":{ +# !! "kernelspec": { +# !! "display_name": "Python 3.10.6 ('dsd')", +# !! "language": "python", +# !! "name": "python3" +# !! }, +# !! "language_info": { +# !! "codemirror_mode": { +# !! "name": "ipython", +# !! "version": 3 +# !! }, +# !! "file_extension": ".py", +# !! "mimetype": "text/x-python", +# !! "name": "python", +# !! "nbconvert_exporter": "python", +# !! "pygments_lexer": "ipython3", +# !! "version": "3.10.6" +# !! }, +# !! "orig_nbformat": 4, +# !! "vscode": { +# !! "interpreter": { +# !! "hash": "b7e04c8a9537645cbc77fa0cbde8069bc94e341b0d5ced104651213865b24e58" +# !! } +# !! }, +# !! "colab": { +# !! "provenance": [] +# !! }, +# !! "accelerator": "GPU", +# !! "gpuClass": "standard" +# !! }} diff --git a/deforum-stable-diffusion/LICENSE b/deforum-stable-diffusion/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..f132f8e0eaf6ff7c6d3ea707fc22dc5e78fe0963 --- /dev/null +++ b/deforum-stable-diffusion/LICENSE @@ -0,0 +1,1806 @@ +deforum-stable-diffusion: +MIT License + +Copyright (c) 2022 deforum and contributors + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + +k-diffusion: +MIT License + +Copyright (c) 2022 Katherine Crowson + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. + +sac: +Copyright (c) 2022 Katherine Crowson and John David Pressman + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. + +clip: +MIT License + +Copyright (c) 2021 OpenAI + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + +MiDaS: +MIT License + +Copyright (c) 2019 Intel ISL (Intel Intelligent Systems Lab) + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + +pytorch3d-lite: +BSD License + +For PyTorch3D software + +Copyright (c) Meta Platforms, Inc. and affiliates. All rights reserved. + +Redistribution and use in source and binary forms, with or without modification, +are permitted provided that the following conditions are met: + + * Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + + * Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + + * Neither the name Meta nor the names of its contributors may be used to + endorse or promote products derived from this software without specific + prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR +ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON +ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +taming-transformers: +Copyright (c) 2020 Patrick Esser and Robin Rombach and Björn Ommer + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, +DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR +OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE +OR OTHER DEALINGS IN THE SOFTWARE./ + +stable diffusion: +Copyright (c) 2022 Robin Rombach and Patrick Esser and contributors + +CreativeML Open RAIL-M +dated August 22, 2022 + +Section I: PREAMBLE + +Multimodal generative models are being widely adopted and used, and have the potential to transform the way artists, among other individuals, conceive and benefit from AI or ML technologies as a tool for content creation. + +Notwithstanding the current and potential benefits that these artifacts can bring to society at large, there are also concerns about potential misuses of them, either due to their technical limitations or ethical considerations. + +In short, this license strives for both the open and responsible downstream use of the accompanying model. When it comes to the open character, we took inspiration from open source permissive licenses regarding the grant of IP rights. Referring to the downstream responsible use, we added use-based restrictions not permitting the use of the Model in very specific scenarios, in order for the licensor to be able to enforce the license in case potential misuses of the Model may occur. At the same time, we strive to promote open and responsible research on generative models for art and content generation. + +Even though downstream derivative versions of the model could be released under different licensing terms, the latter will always have to include - at minimum - the same use-based restrictions as the ones in the original license (this license). We believe in the intersection between open and responsible AI development; thus, this License aims to strike a balance between both in order to enable responsible open-science in the field of AI. + +This License governs the use of the model (and its derivatives) and is informed by the model card associated with the model. + +NOW THEREFORE, You and Licensor agree as follows: + +1. Definitions + +- "License" means the terms and conditions for use, reproduction, and Distribution as defined in this document. +- "Data" means a collection of information and/or content extracted from the dataset used with the Model, including to train, pretrain, or otherwise evaluate the Model. The Data is not licensed under this License. +- "Output" means the results of operating a Model as embodied in informational content resulting therefrom. +- "Model" means any accompanying machine-learning based assemblies (including checkpoints), consisting of learnt weights, parameters (including optimizer states), corresponding to the model architecture as embodied in the Complementary Material, that have been trained or tuned, in whole or in part on the Data, using the Complementary Material. +- "Derivatives of the Model" means all modifications to the Model, works based on the Model, or any other model which is created or initialized by transfer of patterns of the weights, parameters, activations or output of the Model, to the other model, in order to cause the other model to perform similarly to the Model, including - but not limited to - distillation methods entailing the use of intermediate data representations or methods based on the generation of synthetic data by the Model for training the other model. +- "Complementary Material" means the accompanying source code and scripts used to define, run, load, benchmark or evaluate the Model, and used to prepare data for training or evaluation, if any. This includes any accompanying documentation, tutorials, examples, etc, if any. +- "Distribution" means any transmission, reproduction, publication or other sharing of the Model or Derivatives of the Model to a third party, including providing the Model as a hosted service made available by electronic or other remote means - e.g. API-based or web access. +- "Licensor" means the copyright owner or entity authorized by the copyright owner that is granting the License, including the persons or entities that may have rights in the Model and/or distributing the Model. +- "You" (or "Your") means an individual or Legal Entity exercising permissions granted by this License and/or making use of the Model for whichever purpose and in any field of use, including usage of the Model in an end-use application - e.g. chatbot, translator, image generator. +- "Third Parties" means individuals or legal entities that are not under common control with Licensor or You. +- "Contribution" means any work of authorship, including the original version of the Model and any modifications or additions to that Model or Derivatives of the Model thereof, that is intentionally submitted to Licensor for inclusion in the Model by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Model, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." +- "Contributor" means Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Model. + +Section II: INTELLECTUAL PROPERTY RIGHTS + +Both copyright and patent grants apply to the Model, Derivatives of the Model and Complementary Material. The Model and Derivatives of the Model are subject to additional terms as described in Section III. + +2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare, publicly display, publicly perform, sublicense, and distribute the Complementary Material, the Model, and Derivatives of the Model. +3. Grant of Patent License. Subject to the terms and conditions of this License and where and as applicable, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this paragraph) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Model and the Complementary Material, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Model to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Model and/or Complementary Material or a Contribution incorporated within the Model and/or Complementary Material constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for the Model and/or Work shall terminate as of the date such litigation is asserted or filed. + +Section III: CONDITIONS OF USAGE, DISTRIBUTION AND REDISTRIBUTION + +4. Distribution and Redistribution. You may host for Third Party remote access purposes (e.g. software-as-a-service), reproduce and distribute copies of the Model or Derivatives of the Model thereof in any medium, with or without modifications, provided that You meet the following conditions: +Use-based restrictions as referenced in paragraph 5 MUST be included as an enforceable provision by You in any type of legal agreement (e.g. a license) governing the use and/or distribution of the Model or Derivatives of the Model, and You shall give notice to subsequent users You Distribute to, that the Model or Derivatives of the Model are subject to paragraph 5. This provision does not apply to the use of Complementary Material. +You must give any Third Party recipients of the Model or Derivatives of the Model a copy of this License; +You must cause any modified files to carry prominent notices stating that You changed the files; +You must retain all copyright, patent, trademark, and attribution notices excluding those notices that do not pertain to any part of the Model, Derivatives of the Model. +You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions - respecting paragraph 4.a. - for use, reproduction, or Distribution of Your modifications, or for any such Derivatives of the Model as a whole, provided Your use, reproduction, and Distribution of the Model otherwise complies with the conditions stated in this License. +5. Use-based restrictions. The restrictions set forth in Attachment A are considered Use-based restrictions. Therefore You cannot use the Model and the Derivatives of the Model for the specified restricted uses. You may use the Model subject to this License, including only for lawful purposes and in accordance with the License. Use may include creating any content with, finetuning, updating, running, training, evaluating and/or reparametrizing the Model. You shall require all of Your users who use the Model or a Derivative of the Model to comply with the terms of this paragraph (paragraph 5). +6. The Output You Generate. Except as set forth herein, Licensor claims no rights in the Output You generate using the Model. You are accountable for the Output you generate and its subsequent uses. No use of the output can contravene any provision as stated in the License. + +Section IV: OTHER PROVISIONS + +7. Updates and Runtime Restrictions. To the maximum extent permitted by law, Licensor reserves the right to restrict (remotely or otherwise) usage of the Model in violation of this License, update the Model through electronic means, or modify the Output of the Model based on updates. You shall undertake reasonable efforts to use the latest version of the Model. +8. Trademarks and related. Nothing in this License permits You to make use of Licensors’ trademarks, trade names, logos or to otherwise suggest endorsement or misrepresent the relationship between the parties; and any rights not expressly granted herein are reserved by the Licensors. +9. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Model and the Complementary Material (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Model, Derivatives of the Model, and the Complementary Material and assume any risks associated with Your exercise of permissions under this License. +10. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Model and the Complementary Material (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. +11. Accepting Warranty or Additional Liability. While redistributing the Model, Derivatives of the Model and the Complementary Material thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. +12. If any provision of this License is held to be invalid, illegal or unenforceable, the remaining provisions shall be unaffected thereby and remain valid as if such provision had not been set forth herein. + +END OF TERMS AND CONDITIONS + + + + +Attachment A + +Use Restrictions + +You agree not to use the Model or Derivatives of the Model: +- In any way that violates any applicable national, federal, state, local or international law or regulation; +- For the purpose of exploiting, harming or attempting to exploit or harm minors in any way; +- To generate or disseminate verifiably false information and/or content with the purpose of harming others; +- To generate or disseminate personal identifiable information that can be used to harm an individual; +- To defame, disparage or otherwise harass others; +- For fully automated decision making that adversely impacts an individual’s legal rights or otherwise creates or modifies a binding, enforceable obligation; +- For any use intended to or which has the effect of discriminating against or harming individuals or groups based on online or offline social behavior or known or predicted personal or personality characteristics; +- To exploit any of the vulnerabilities of a specific group of persons based on their age, social, physical or mental characteristics, in order to materially distort the behavior of a person pertaining to that group in a manner that causes or is likely to cause that person or another person physical or psychological harm; +- For any use intended to or which has the effect of discriminating against individuals or groups based on legally protected characteristics or categories; +- To provide medical advice and medical results interpretation; +- To generate or disseminate information for the purpose to be used for administration of justice, law enforcement, immigration or asylum processes, such as predicting an individual will commit fraud/crime commitment (e.g. by text profiling, drawing causal relationships between assertions made in documents, indiscriminate and arbitrarily-targeted use). + +transformers: +Copyright 2018- The Hugging Face team. All rights reserved. + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + +adabins: + GNU GENERAL PUBLIC LICENSE + Version 3, 29 June 2007 + + Copyright (C) 2007 Free Software Foundation, Inc. + Everyone is permitted to copy and distribute verbatim copies + of this license document, but changing it is not allowed. + + Preamble + + The GNU General Public License is a free, copyleft license for +software and other kinds of works. + + The licenses for most software and other practical works are designed +to take away your freedom to share and change the works. By contrast, +the GNU General Public License is intended to guarantee your freedom to +share and change all versions of a program--to make sure it remains free +software for all its users. We, the Free Software Foundation, use the +GNU General Public License for most of our software; it applies also to +any other work released this way by its authors. You can apply it to +your programs, too. + + When we speak of free software, we are referring to freedom, not +price. Our General Public Licenses are designed to make sure that you +have the freedom to distribute copies of free software (and charge for +them if you wish), that you receive source code or can get it if you +want it, that you can change the software or use pieces of it in new +free programs, and that you know you can do these things. + + To protect your rights, we need to prevent others from denying you +these rights or asking you to surrender the rights. Therefore, you have +certain responsibilities if you distribute copies of the software, or if +you modify it: responsibilities to respect the freedom of others. + + For example, if you distribute copies of such a program, whether +gratis or for a fee, you must pass on to the recipients the same +freedoms that you received. You must make sure that they, too, receive +or can get the source code. And you must show them these terms so they +know their rights. + + Developers that use the GNU GPL protect your rights with two steps: +(1) assert copyright on the software, and (2) offer you this License +giving you legal permission to copy, distribute and/or modify it. + + For the developers' and authors' protection, the GPL clearly explains +that there is no warranty for this free software. For both users' and +authors' sake, the GPL requires that modified versions be marked as +changed, so that their problems will not be attributed erroneously to +authors of previous versions. + + Some devices are designed to deny users access to install or run +modified versions of the software inside them, although the manufacturer +can do so. This is fundamentally incompatible with the aim of +protecting users' freedom to change the software. The systematic +pattern of such abuse occurs in the area of products for individuals to +use, which is precisely where it is most unacceptable. Therefore, we +have designed this version of the GPL to prohibit the practice for those +products. If such problems arise substantially in other domains, we +stand ready to extend this provision to those domains in future versions +of the GPL, as needed to protect the freedom of users. + + Finally, every program is threatened constantly by software patents. +States should not allow patents to restrict development and use of +software on general-purpose computers, but in those that do, we wish to +avoid the special danger that patents applied to a free program could +make it effectively proprietary. To prevent this, the GPL assures that +patents cannot be used to render the program non-free. + + The precise terms and conditions for copying, distribution and +modification follow. + + TERMS AND CONDITIONS + + 0. Definitions. + + "This License" refers to version 3 of the GNU General Public License. + + "Copyright" also means copyright-like laws that apply to other kinds of +works, such as semiconductor masks. + + "The Program" refers to any copyrightable work licensed under this +License. Each licensee is addressed as "you". "Licensees" and +"recipients" may be individuals or organizations. + + To "modify" a work means to copy from or adapt all or part of the work +in a fashion requiring copyright permission, other than the making of an +exact copy. The resulting work is called a "modified version" of the +earlier work or a work "based on" the earlier work. + + A "covered work" means either the unmodified Program or a work based +on the Program. + + To "propagate" a work means to do anything with it that, without +permission, would make you directly or secondarily liable for +infringement under applicable copyright law, except executing it on a +computer or modifying a private copy. Propagation includes copying, +distribution (with or without modification), making available to the +public, and in some countries other activities as well. + + To "convey" a work means any kind of propagation that enables other +parties to make or receive copies. Mere interaction with a user through +a computer network, with no transfer of a copy, is not conveying. + + An interactive user interface displays "Appropriate Legal Notices" +to the extent that it includes a convenient and prominently visible +feature that (1) displays an appropriate copyright notice, and (2) +tells the user that there is no warranty for the work (except to the +extent that warranties are provided), that licensees may convey the +work under this License, and how to view a copy of this License. If +the interface presents a list of user commands or options, such as a +menu, a prominent item in the list meets this criterion. + + 1. Source Code. + + The "source code" for a work means the preferred form of the work +for making modifications to it. "Object code" means any non-source +form of a work. + + A "Standard Interface" means an interface that either is an official +standard defined by a recognized standards body, or, in the case of +interfaces specified for a particular programming language, one that +is widely used among developers working in that language. + + The "System Libraries" of an executable work include anything, other +than the work as a whole, that (a) is included in the normal form of +packaging a Major Component, but which is not part of that Major +Component, and (b) serves only to enable use of the work with that +Major Component, or to implement a Standard Interface for which an +implementation is available to the public in source code form. A +"Major Component", in this context, means a major essential component +(kernel, window system, and so on) of the specific operating system +(if any) on which the executable work runs, or a compiler used to +produce the work, or an object code interpreter used to run it. + + The "Corresponding Source" for a work in object code form means all +the source code needed to generate, install, and (for an executable +work) run the object code and to modify the work, including scripts to +control those activities. However, it does not include the work's +System Libraries, or general-purpose tools or generally available free +programs which are used unmodified in performing those activities but +which are not part of the work. For example, Corresponding Source +includes interface definition files associated with source files for +the work, and the source code for shared libraries and dynamically +linked subprograms that the work is specifically designed to require, +such as by intimate data communication or control flow between those +subprograms and other parts of the work. + + The Corresponding Source need not include anything that users +can regenerate automatically from other parts of the Corresponding +Source. + + The Corresponding Source for a work in source code form is that +same work. + + 2. Basic Permissions. + + All rights granted under this License are granted for the term of +copyright on the Program, and are irrevocable provided the stated +conditions are met. This License explicitly affirms your unlimited +permission to run the unmodified Program. The output from running a +covered work is covered by this License only if the output, given its +content, constitutes a covered work. This License acknowledges your +rights of fair use or other equivalent, as provided by copyright law. + + You may make, run and propagate covered works that you do not +convey, without conditions so long as your license otherwise remains +in force. You may convey covered works to others for the sole purpose +of having them make modifications exclusively for you, or provide you +with facilities for running those works, provided that you comply with +the terms of this License in conveying all material for which you do +not control copyright. Those thus making or running the covered works +for you must do so exclusively on your behalf, under your direction +and control, on terms that prohibit them from making any copies of +your copyrighted material outside their relationship with you. + + Conveying under any other circumstances is permitted solely under +the conditions stated below. Sublicensing is not allowed; section 10 +makes it unnecessary. + + 3. Protecting Users' Legal Rights From Anti-Circumvention Law. + + No covered work shall be deemed part of an effective technological +measure under any applicable law fulfilling obligations under article +11 of the WIPO copyright treaty adopted on 20 December 1996, or +similar laws prohibiting or restricting circumvention of such +measures. + + When you convey a covered work, you waive any legal power to forbid +circumvention of technological measures to the extent such circumvention +is effected by exercising rights under this License with respect to +the covered work, and you disclaim any intention to limit operation or +modification of the work as a means of enforcing, against the work's +users, your or third parties' legal rights to forbid circumvention of +technological measures. + + 4. Conveying Verbatim Copies. + + You may convey verbatim copies of the Program's source code as you +receive it, in any medium, provided that you conspicuously and +appropriately publish on each copy an appropriate copyright notice; +keep intact all notices stating that this License and any +non-permissive terms added in accord with section 7 apply to the code; +keep intact all notices of the absence of any warranty; and give all +recipients a copy of this License along with the Program. + + You may charge any price or no price for each copy that you convey, +and you may offer support or warranty protection for a fee. + + 5. Conveying Modified Source Versions. + + You may convey a work based on the Program, or the modifications to +produce it from the Program, in the form of source code under the +terms of section 4, provided that you also meet all of these conditions: + + a) The work must carry prominent notices stating that you modified + it, and giving a relevant date. + + b) The work must carry prominent notices stating that it is + released under this License and any conditions added under section + 7. This requirement modifies the requirement in section 4 to + "keep intact all notices". + + c) You must license the entire work, as a whole, under this + License to anyone who comes into possession of a copy. This + License will therefore apply, along with any applicable section 7 + additional terms, to the whole of the work, and all its parts, + regardless of how they are packaged. This License gives no + permission to license the work in any other way, but it does not + invalidate such permission if you have separately received it. + + d) If the work has interactive user interfaces, each must display + Appropriate Legal Notices; however, if the Program has interactive + interfaces that do not display Appropriate Legal Notices, your + work need not make them do so. + + A compilation of a covered work with other separate and independent +works, which are not by their nature extensions of the covered work, +and which are not combined with it such as to form a larger program, +in or on a volume of a storage or distribution medium, is called an +"aggregate" if the compilation and its resulting copyright are not +used to limit the access or legal rights of the compilation's users +beyond what the individual works permit. Inclusion of a covered work +in an aggregate does not cause this License to apply to the other +parts of the aggregate. + + 6. Conveying Non-Source Forms. + + You may convey a covered work in object code form under the terms +of sections 4 and 5, provided that you also convey the +machine-readable Corresponding Source under the terms of this License, +in one of these ways: + + a) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by the + Corresponding Source fixed on a durable physical medium + customarily used for software interchange. + + b) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by a + written offer, valid for at least three years and valid for as + long as you offer spare parts or customer support for that product + model, to give anyone who possesses the object code either (1) a + copy of the Corresponding Source for all the software in the + product that is covered by this License, on a durable physical + medium customarily used for software interchange, for a price no + more than your reasonable cost of physically performing this + conveying of source, or (2) access to copy the + Corresponding Source from a network server at no charge. + + c) Convey individual copies of the object code with a copy of the + written offer to provide the Corresponding Source. This + alternative is allowed only occasionally and noncommercially, and + only if you received the object code with such an offer, in accord + with subsection 6b. + + d) Convey the object code by offering access from a designated + place (gratis or for a charge), and offer equivalent access to the + Corresponding Source in the same way through the same place at no + further charge. You need not require recipients to copy the + Corresponding Source along with the object code. If the place to + copy the object code is a network server, the Corresponding Source + may be on a different server (operated by you or a third party) + that supports equivalent copying facilities, provided you maintain + clear directions next to the object code saying where to find the + Corresponding Source. Regardless of what server hosts the + Corresponding Source, you remain obligated to ensure that it is + available for as long as needed to satisfy these requirements. + + e) Convey the object code using peer-to-peer transmission, provided + you inform other peers where the object code and Corresponding + Source of the work are being offered to the general public at no + charge under subsection 6d. + + A separable portion of the object code, whose source code is excluded +from the Corresponding Source as a System Library, need not be +included in conveying the object code work. + + A "User Product" is either (1) a "consumer product", which means any +tangible personal property which is normally used for personal, family, +or household purposes, or (2) anything designed or sold for incorporation +into a dwelling. In determining whether a product is a consumer product, +doubtful cases shall be resolved in favor of coverage. For a particular +product received by a particular user, "normally used" refers to a +typical or common use of that class of product, regardless of the status +of the particular user or of the way in which the particular user +actually uses, or expects or is expected to use, the product. A product +is a consumer product regardless of whether the product has substantial +commercial, industrial or non-consumer uses, unless such uses represent +the only significant mode of use of the product. + + "Installation Information" for a User Product means any methods, +procedures, authorization keys, or other information required to install +and execute modified versions of a covered work in that User Product from +a modified version of its Corresponding Source. The information must +suffice to ensure that the continued functioning of the modified object +code is in no case prevented or interfered with solely because +modification has been made. + + If you convey an object code work under this section in, or with, or +specifically for use in, a User Product, and the conveying occurs as +part of a transaction in which the right of possession and use of the +User Product is transferred to the recipient in perpetuity or for a +fixed term (regardless of how the transaction is characterized), the +Corresponding Source conveyed under this section must be accompanied +by the Installation Information. But this requirement does not apply +if neither you nor any third party retains the ability to install +modified object code on the User Product (for example, the work has +been installed in ROM). + + The requirement to provide Installation Information does not include a +requirement to continue to provide support service, warranty, or updates +for a work that has been modified or installed by the recipient, or for +the User Product in which it has been modified or installed. Access to a +network may be denied when the modification itself materially and +adversely affects the operation of the network or violates the rules and +protocols for communication across the network. + + Corresponding Source conveyed, and Installation Information provided, +in accord with this section must be in a format that is publicly +documented (and with an implementation available to the public in +source code form), and must require no special password or key for +unpacking, reading or copying. + + 7. Additional Terms. + + "Additional permissions" are terms that supplement the terms of this +License by making exceptions from one or more of its conditions. +Additional permissions that are applicable to the entire Program shall +be treated as though they were included in this License, to the extent +that they are valid under applicable law. If additional permissions +apply only to part of the Program, that part may be used separately +under those permissions, but the entire Program remains governed by +this License without regard to the additional permissions. + + When you convey a copy of a covered work, you may at your option +remove any additional permissions from that copy, or from any part of +it. (Additional permissions may be written to require their own +removal in certain cases when you modify the work.) You may place +additional permissions on material, added by you to a covered work, +for which you have or can give appropriate copyright permission. + + Notwithstanding any other provision of this License, for material you +add to a covered work, you may (if authorized by the copyright holders of +that material) supplement the terms of this License with terms: + + a) Disclaiming warranty or limiting liability differently from the + terms of sections 15 and 16 of this License; or + + b) Requiring preservation of specified reasonable legal notices or + author attributions in that material or in the Appropriate Legal + Notices displayed by works containing it; or + + c) Prohibiting misrepresentation of the origin of that material, or + requiring that modified versions of such material be marked in + reasonable ways as different from the original version; or + + d) Limiting the use for publicity purposes of names of licensors or + authors of the material; or + + e) Declining to grant rights under trademark law for use of some + trade names, trademarks, or service marks; or + + f) Requiring indemnification of licensors and authors of that + material by anyone who conveys the material (or modified versions of + it) with contractual assumptions of liability to the recipient, for + any liability that these contractual assumptions directly impose on + those licensors and authors. + + All other non-permissive additional terms are considered "further +restrictions" within the meaning of section 10. If the Program as you +received it, or any part of it, contains a notice stating that it is +governed by this License along with a term that is a further +restriction, you may remove that term. If a license document contains +a further restriction but permits relicensing or conveying under this +License, you may add to a covered work material governed by the terms +of that license document, provided that the further restriction does +not survive such relicensing or conveying. + + If you add terms to a covered work in accord with this section, you +must place, in the relevant source files, a statement of the +additional terms that apply to those files, or a notice indicating +where to find the applicable terms. + + Additional terms, permissive or non-permissive, may be stated in the +form of a separately written license, or stated as exceptions; +the above requirements apply either way. + + 8. Termination. + + You may not propagate or modify a covered work except as expressly +provided under this License. Any attempt otherwise to propagate or +modify it is void, and will automatically terminate your rights under +this License (including any patent licenses granted under the third +paragraph of section 11). + + However, if you cease all violation of this License, then your +license from a particular copyright holder is reinstated (a) +provisionally, unless and until the copyright holder explicitly and +finally terminates your license, and (b) permanently, if the copyright +holder fails to notify you of the violation by some reasonable means +prior to 60 days after the cessation. + + Moreover, your license from a particular copyright holder is +reinstated permanently if the copyright holder notifies you of the +violation by some reasonable means, this is the first time you have +received notice of violation of this License (for any work) from that +copyright holder, and you cure the violation prior to 30 days after +your receipt of the notice. + + Termination of your rights under this section does not terminate the +licenses of parties who have received copies or rights from you under +this License. If your rights have been terminated and not permanently +reinstated, you do not qualify to receive new licenses for the same +material under section 10. + + 9. Acceptance Not Required for Having Copies. + + You are not required to accept this License in order to receive or +run a copy of the Program. Ancillary propagation of a covered work +occurring solely as a consequence of using peer-to-peer transmission +to receive a copy likewise does not require acceptance. However, +nothing other than this License grants you permission to propagate or +modify any covered work. These actions infringe copyright if you do +not accept this License. Therefore, by modifying or propagating a +covered work, you indicate your acceptance of this License to do so. + + 10. Automatic Licensing of Downstream Recipients. + + Each time you convey a covered work, the recipient automatically +receives a license from the original licensors, to run, modify and +propagate that work, subject to this License. You are not responsible +for enforcing compliance by third parties with this License. + + An "entity transaction" is a transaction transferring control of an +organization, or substantially all assets of one, or subdividing an +organization, or merging organizations. If propagation of a covered +work results from an entity transaction, each party to that +transaction who receives a copy of the work also receives whatever +licenses to the work the party's predecessor in interest had or could +give under the previous paragraph, plus a right to possession of the +Corresponding Source of the work from the predecessor in interest, if +the predecessor has it or can get it with reasonable efforts. + + You may not impose any further restrictions on the exercise of the +rights granted or affirmed under this License. For example, you may +not impose a license fee, royalty, or other charge for exercise of +rights granted under this License, and you may not initiate litigation +(including a cross-claim or counterclaim in a lawsuit) alleging that +any patent claim is infringed by making, using, selling, offering for +sale, or importing the Program or any portion of it. + + 11. Patents. + + A "contributor" is a copyright holder who authorizes use under this +License of the Program or a work on which the Program is based. The +work thus licensed is called the contributor's "contributor version". + + A contributor's "essential patent claims" are all patent claims +owned or controlled by the contributor, whether already acquired or +hereafter acquired, that would be infringed by some manner, permitted +by this License, of making, using, or selling its contributor version, +but do not include claims that would be infringed only as a +consequence of further modification of the contributor version. For +purposes of this definition, "control" includes the right to grant +patent sublicenses in a manner consistent with the requirements of +this License. + + Each contributor grants you a non-exclusive, worldwide, royalty-free +patent license under the contributor's essential patent claims, to +make, use, sell, offer for sale, import and otherwise run, modify and +propagate the contents of its contributor version. + + In the following three paragraphs, a "patent license" is any express +agreement or commitment, however denominated, not to enforce a patent +(such as an express permission to practice a patent or covenant not to +sue for patent infringement). To "grant" such a patent license to a +party means to make such an agreement or commitment not to enforce a +patent against the party. + + If you convey a covered work, knowingly relying on a patent license, +and the Corresponding Source of the work is not available for anyone +to copy, free of charge and under the terms of this License, through a +publicly available network server or other readily accessible means, +then you must either (1) cause the Corresponding Source to be so +available, or (2) arrange to deprive yourself of the benefit of the +patent license for this particular work, or (3) arrange, in a manner +consistent with the requirements of this License, to extend the patent +license to downstream recipients. "Knowingly relying" means you have +actual knowledge that, but for the patent license, your conveying the +covered work in a country, or your recipient's use of the covered work +in a country, would infringe one or more identifiable patents in that +country that you have reason to believe are valid. + + If, pursuant to or in connection with a single transaction or +arrangement, you convey, or propagate by procuring conveyance of, a +covered work, and grant a patent license to some of the parties +receiving the covered work authorizing them to use, propagate, modify +or convey a specific copy of the covered work, then the patent license +you grant is automatically extended to all recipients of the covered +work and works based on it. + + A patent license is "discriminatory" if it does not include within +the scope of its coverage, prohibits the exercise of, or is +conditioned on the non-exercise of one or more of the rights that are +specifically granted under this License. You may not convey a covered +work if you are a party to an arrangement with a third party that is +in the business of distributing software, under which you make payment +to the third party based on the extent of your activity of conveying +the work, and under which the third party grants, to any of the +parties who would receive the covered work from you, a discriminatory +patent license (a) in connection with copies of the covered work +conveyed by you (or copies made from those copies), or (b) primarily +for and in connection with specific products or compilations that +contain the covered work, unless you entered into that arrangement, +or that patent license was granted, prior to 28 March 2007. + + Nothing in this License shall be construed as excluding or limiting +any implied license or other defenses to infringement that may +otherwise be available to you under applicable patent law. + + 12. No Surrender of Others' Freedom. + + If conditions are imposed on you (whether by court order, agreement or +otherwise) that contradict the conditions of this License, they do not +excuse you from the conditions of this License. If you cannot convey a +covered work so as to satisfy simultaneously your obligations under this +License and any other pertinent obligations, then as a consequence you may +not convey it at all. For example, if you agree to terms that obligate you +to collect a royalty for further conveying from those to whom you convey +the Program, the only way you could satisfy both those terms and this +License would be to refrain entirely from conveying the Program. + + 13. Use with the GNU Affero General Public License. + + Notwithstanding any other provision of this License, you have +permission to link or combine any covered work with a work licensed +under version 3 of the GNU Affero General Public License into a single +combined work, and to convey the resulting work. The terms of this +License will continue to apply to the part which is the covered work, +but the special requirements of the GNU Affero General Public License, +section 13, concerning interaction through a network will apply to the +combination as such. + + 14. Revised Versions of this License. + + The Free Software Foundation may publish revised and/or new versions of +the GNU General Public License from time to time. Such new versions will +be similar in spirit to the present version, but may differ in detail to +address new problems or concerns. + + Each version is given a distinguishing version number. If the +Program specifies that a certain numbered version of the GNU General +Public License "or any later version" applies to it, you have the +option of following the terms and conditions either of that numbered +version or of any later version published by the Free Software +Foundation. If the Program does not specify a version number of the +GNU General Public License, you may choose any version ever published +by the Free Software Foundation. + + If the Program specifies that a proxy can decide which future +versions of the GNU General Public License can be used, that proxy's +public statement of acceptance of a version permanently authorizes you +to choose that version for the Program. + + Later license versions may give you additional or different +permissions. However, no additional obligations are imposed on any +author or copyright holder as a result of your choosing to follow a +later version. + + 15. Disclaimer of Warranty. + + THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY +APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT +HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY +OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, +THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM +IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF +ALL NECESSARY SERVICING, REPAIR OR CORRECTION. + + 16. Limitation of Liability. + + IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING +WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS +THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY +GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE +USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF +DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD +PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), +EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF +SUCH DAMAGES. + + 17. Interpretation of Sections 15 and 16. + + If the disclaimer of warranty and limitation of liability provided +above cannot be given local legal effect according to their terms, +reviewing courts shall apply local law that most closely approximates +an absolute waiver of all civil liability in connection with the +Program, unless a warranty or assumption of liability accompanies a +copy of the Program in return for a fee. + + END OF TERMS AND CONDITIONS + + How to Apply These Terms to Your New Programs + + If you develop a new program, and you want it to be of the greatest +possible use to the public, the best way to achieve this is to make it +free software which everyone can redistribute and change under these terms. + + To do so, attach the following notices to the program. It is safest +to attach them to the start of each source file to most effectively +state the exclusion of warranty; and each file should have at least +the "copyright" line and a pointer to where the full notice is found. + + + Copyright (C) + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see . + +Also add information on how to contact you by electronic and paper mail. + + If the program does terminal interaction, make it output a short +notice like this when it starts in an interactive mode: + + Copyright (C) + This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'. + This is free software, and you are welcome to redistribute it + under certain conditions; type `show c' for details. + +The hypothetical commands `show w' and `show c' should show the appropriate +parts of the General Public License. Of course, your program's commands +might be different; for a GUI interface, you would use an "about box". + + You should also get your employer (if you work as a programmer) or school, +if any, to sign a "copyright disclaimer" for the program, if necessary. +For more information on this, and how to apply and follow the GNU GPL, see +. + + The GNU General Public License does not permit incorporating your program +into proprietary programs. If your program is a subroutine library, you +may consider it more useful to permit linking proprietary applications with +the library. If this is what you want to do, use the GNU Lesser General +Public License instead of this License. But first, please read +. + +adabins: + GNU GENERAL PUBLIC LICENSE + Version 3, 29 June 2007 + + Copyright (C) 2007 Free Software Foundation, Inc. + Everyone is permitted to copy and distribute verbatim copies + of this license document, but changing it is not allowed. + + Preamble + + The GNU General Public License is a free, copyleft license for +software and other kinds of works. + + The licenses for most software and other practical works are designed +to take away your freedom to share and change the works. By contrast, +the GNU General Public License is intended to guarantee your freedom to +share and change all versions of a program--to make sure it remains free +software for all its users. We, the Free Software Foundation, use the +GNU General Public License for most of our software; it applies also to +any other work released this way by its authors. You can apply it to +your programs, too. + + When we speak of free software, we are referring to freedom, not +price. Our General Public Licenses are designed to make sure that you +have the freedom to distribute copies of free software (and charge for +them if you wish), that you receive source code or can get it if you +want it, that you can change the software or use pieces of it in new +free programs, and that you know you can do these things. + + To protect your rights, we need to prevent others from denying you +these rights or asking you to surrender the rights. Therefore, you have +certain responsibilities if you distribute copies of the software, or if +you modify it: responsibilities to respect the freedom of others. + + For example, if you distribute copies of such a program, whether +gratis or for a fee, you must pass on to the recipients the same +freedoms that you received. You must make sure that they, too, receive +or can get the source code. And you must show them these terms so they +know their rights. + + Developers that use the GNU GPL protect your rights with two steps: +(1) assert copyright on the software, and (2) offer you this License +giving you legal permission to copy, distribute and/or modify it. + + For the developers' and authors' protection, the GPL clearly explains +that there is no warranty for this free software. For both users' and +authors' sake, the GPL requires that modified versions be marked as +changed, so that their problems will not be attributed erroneously to +authors of previous versions. + + Some devices are designed to deny users access to install or run +modified versions of the software inside them, although the manufacturer +can do so. This is fundamentally incompatible with the aim of +protecting users' freedom to change the software. The systematic +pattern of such abuse occurs in the area of products for individuals to +use, which is precisely where it is most unacceptable. Therefore, we +have designed this version of the GPL to prohibit the practice for those +products. If such problems arise substantially in other domains, we +stand ready to extend this provision to those domains in future versions +of the GPL, as needed to protect the freedom of users. + + Finally, every program is threatened constantly by software patents. +States should not allow patents to restrict development and use of +software on general-purpose computers, but in those that do, we wish to +avoid the special danger that patents applied to a free program could +make it effectively proprietary. To prevent this, the GPL assures that +patents cannot be used to render the program non-free. + + The precise terms and conditions for copying, distribution and +modification follow. + + TERMS AND CONDITIONS + + 0. Definitions. + + "This License" refers to version 3 of the GNU General Public License. + + "Copyright" also means copyright-like laws that apply to other kinds of +works, such as semiconductor masks. + + "The Program" refers to any copyrightable work licensed under this +License. Each licensee is addressed as "you". "Licensees" and +"recipients" may be individuals or organizations. + + To "modify" a work means to copy from or adapt all or part of the work +in a fashion requiring copyright permission, other than the making of an +exact copy. The resulting work is called a "modified version" of the +earlier work or a work "based on" the earlier work. + + A "covered work" means either the unmodified Program or a work based +on the Program. + + To "propagate" a work means to do anything with it that, without +permission, would make you directly or secondarily liable for +infringement under applicable copyright law, except executing it on a +computer or modifying a private copy. Propagation includes copying, +distribution (with or without modification), making available to the +public, and in some countries other activities as well. + + To "convey" a work means any kind of propagation that enables other +parties to make or receive copies. Mere interaction with a user through +a computer network, with no transfer of a copy, is not conveying. + + An interactive user interface displays "Appropriate Legal Notices" +to the extent that it includes a convenient and prominently visible +feature that (1) displays an appropriate copyright notice, and (2) +tells the user that there is no warranty for the work (except to the +extent that warranties are provided), that licensees may convey the +work under this License, and how to view a copy of this License. If +the interface presents a list of user commands or options, such as a +menu, a prominent item in the list meets this criterion. + + 1. Source Code. + + The "source code" for a work means the preferred form of the work +for making modifications to it. "Object code" means any non-source +form of a work. + + A "Standard Interface" means an interface that either is an official +standard defined by a recognized standards body, or, in the case of +interfaces specified for a particular programming language, one that +is widely used among developers working in that language. + + The "System Libraries" of an executable work include anything, other +than the work as a whole, that (a) is included in the normal form of +packaging a Major Component, but which is not part of that Major +Component, and (b) serves only to enable use of the work with that +Major Component, or to implement a Standard Interface for which an +implementation is available to the public in source code form. A +"Major Component", in this context, means a major essential component +(kernel, window system, and so on) of the specific operating system +(if any) on which the executable work runs, or a compiler used to +produce the work, or an object code interpreter used to run it. + + The "Corresponding Source" for a work in object code form means all +the source code needed to generate, install, and (for an executable +work) run the object code and to modify the work, including scripts to +control those activities. However, it does not include the work's +System Libraries, or general-purpose tools or generally available free +programs which are used unmodified in performing those activities but +which are not part of the work. For example, Corresponding Source +includes interface definition files associated with source files for +the work, and the source code for shared libraries and dynamically +linked subprograms that the work is specifically designed to require, +such as by intimate data communication or control flow between those +subprograms and other parts of the work. + + The Corresponding Source need not include anything that users +can regenerate automatically from other parts of the Corresponding +Source. + + The Corresponding Source for a work in source code form is that +same work. + + 2. Basic Permissions. + + All rights granted under this License are granted for the term of +copyright on the Program, and are irrevocable provided the stated +conditions are met. This License explicitly affirms your unlimited +permission to run the unmodified Program. The output from running a +covered work is covered by this License only if the output, given its +content, constitutes a covered work. This License acknowledges your +rights of fair use or other equivalent, as provided by copyright law. + + You may make, run and propagate covered works that you do not +convey, without conditions so long as your license otherwise remains +in force. You may convey covered works to others for the sole purpose +of having them make modifications exclusively for you, or provide you +with facilities for running those works, provided that you comply with +the terms of this License in conveying all material for which you do +not control copyright. Those thus making or running the covered works +for you must do so exclusively on your behalf, under your direction +and control, on terms that prohibit them from making any copies of +your copyrighted material outside their relationship with you. + + Conveying under any other circumstances is permitted solely under +the conditions stated below. Sublicensing is not allowed; section 10 +makes it unnecessary. + + 3. Protecting Users' Legal Rights From Anti-Circumvention Law. + + No covered work shall be deemed part of an effective technological +measure under any applicable law fulfilling obligations under article +11 of the WIPO copyright treaty adopted on 20 December 1996, or +similar laws prohibiting or restricting circumvention of such +measures. + + When you convey a covered work, you waive any legal power to forbid +circumvention of technological measures to the extent such circumvention +is effected by exercising rights under this License with respect to +the covered work, and you disclaim any intention to limit operation or +modification of the work as a means of enforcing, against the work's +users, your or third parties' legal rights to forbid circumvention of +technological measures. + + 4. Conveying Verbatim Copies. + + You may convey verbatim copies of the Program's source code as you +receive it, in any medium, provided that you conspicuously and +appropriately publish on each copy an appropriate copyright notice; +keep intact all notices stating that this License and any +non-permissive terms added in accord with section 7 apply to the code; +keep intact all notices of the absence of any warranty; and give all +recipients a copy of this License along with the Program. + + You may charge any price or no price for each copy that you convey, +and you may offer support or warranty protection for a fee. + + 5. Conveying Modified Source Versions. + + You may convey a work based on the Program, or the modifications to +produce it from the Program, in the form of source code under the +terms of section 4, provided that you also meet all of these conditions: + + a) The work must carry prominent notices stating that you modified + it, and giving a relevant date. + + b) The work must carry prominent notices stating that it is + released under this License and any conditions added under section + 7. This requirement modifies the requirement in section 4 to + "keep intact all notices". + + c) You must license the entire work, as a whole, under this + License to anyone who comes into possession of a copy. This + License will therefore apply, along with any applicable section 7 + additional terms, to the whole of the work, and all its parts, + regardless of how they are packaged. This License gives no + permission to license the work in any other way, but it does not + invalidate such permission if you have separately received it. + + d) If the work has interactive user interfaces, each must display + Appropriate Legal Notices; however, if the Program has interactive + interfaces that do not display Appropriate Legal Notices, your + work need not make them do so. + + A compilation of a covered work with other separate and independent +works, which are not by their nature extensions of the covered work, +and which are not combined with it such as to form a larger program, +in or on a volume of a storage or distribution medium, is called an +"aggregate" if the compilation and its resulting copyright are not +used to limit the access or legal rights of the compilation's users +beyond what the individual works permit. Inclusion of a covered work +in an aggregate does not cause this License to apply to the other +parts of the aggregate. + + 6. Conveying Non-Source Forms. + + You may convey a covered work in object code form under the terms +of sections 4 and 5, provided that you also convey the +machine-readable Corresponding Source under the terms of this License, +in one of these ways: + + a) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by the + Corresponding Source fixed on a durable physical medium + customarily used for software interchange. + + b) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by a + written offer, valid for at least three years and valid for as + long as you offer spare parts or customer support for that product + model, to give anyone who possesses the object code either (1) a + copy of the Corresponding Source for all the software in the + product that is covered by this License, on a durable physical + medium customarily used for software interchange, for a price no + more than your reasonable cost of physically performing this + conveying of source, or (2) access to copy the + Corresponding Source from a network server at no charge. + + c) Convey individual copies of the object code with a copy of the + written offer to provide the Corresponding Source. This + alternative is allowed only occasionally and noncommercially, and + only if you received the object code with such an offer, in accord + with subsection 6b. + + d) Convey the object code by offering access from a designated + place (gratis or for a charge), and offer equivalent access to the + Corresponding Source in the same way through the same place at no + further charge. You need not require recipients to copy the + Corresponding Source along with the object code. If the place to + copy the object code is a network server, the Corresponding Source + may be on a different server (operated by you or a third party) + that supports equivalent copying facilities, provided you maintain + clear directions next to the object code saying where to find the + Corresponding Source. Regardless of what server hosts the + Corresponding Source, you remain obligated to ensure that it is + available for as long as needed to satisfy these requirements. + + e) Convey the object code using peer-to-peer transmission, provided + you inform other peers where the object code and Corresponding + Source of the work are being offered to the general public at no + charge under subsection 6d. + + A separable portion of the object code, whose source code is excluded +from the Corresponding Source as a System Library, need not be +included in conveying the object code work. + + A "User Product" is either (1) a "consumer product", which means any +tangible personal property which is normally used for personal, family, +or household purposes, or (2) anything designed or sold for incorporation +into a dwelling. In determining whether a product is a consumer product, +doubtful cases shall be resolved in favor of coverage. For a particular +product received by a particular user, "normally used" refers to a +typical or common use of that class of product, regardless of the status +of the particular user or of the way in which the particular user +actually uses, or expects or is expected to use, the product. A product +is a consumer product regardless of whether the product has substantial +commercial, industrial or non-consumer uses, unless such uses represent +the only significant mode of use of the product. + + "Installation Information" for a User Product means any methods, +procedures, authorization keys, or other information required to install +and execute modified versions of a covered work in that User Product from +a modified version of its Corresponding Source. The information must +suffice to ensure that the continued functioning of the modified object +code is in no case prevented or interfered with solely because +modification has been made. + + If you convey an object code work under this section in, or with, or +specifically for use in, a User Product, and the conveying occurs as +part of a transaction in which the right of possession and use of the +User Product is transferred to the recipient in perpetuity or for a +fixed term (regardless of how the transaction is characterized), the +Corresponding Source conveyed under this section must be accompanied +by the Installation Information. But this requirement does not apply +if neither you nor any third party retains the ability to install +modified object code on the User Product (for example, the work has +been installed in ROM). + + The requirement to provide Installation Information does not include a +requirement to continue to provide support service, warranty, or updates +for a work that has been modified or installed by the recipient, or for +the User Product in which it has been modified or installed. Access to a +network may be denied when the modification itself materially and +adversely affects the operation of the network or violates the rules and +protocols for communication across the network. + + Corresponding Source conveyed, and Installation Information provided, +in accord with this section must be in a format that is publicly +documented (and with an implementation available to the public in +source code form), and must require no special password or key for +unpacking, reading or copying. + + 7. Additional Terms. + + "Additional permissions" are terms that supplement the terms of this +License by making exceptions from one or more of its conditions. +Additional permissions that are applicable to the entire Program shall +be treated as though they were included in this License, to the extent +that they are valid under applicable law. If additional permissions +apply only to part of the Program, that part may be used separately +under those permissions, but the entire Program remains governed by +this License without regard to the additional permissions. + + When you convey a copy of a covered work, you may at your option +remove any additional permissions from that copy, or from any part of +it. (Additional permissions may be written to require their own +removal in certain cases when you modify the work.) You may place +additional permissions on material, added by you to a covered work, +for which you have or can give appropriate copyright permission. + + Notwithstanding any other provision of this License, for material you +add to a covered work, you may (if authorized by the copyright holders of +that material) supplement the terms of this License with terms: + + a) Disclaiming warranty or limiting liability differently from the + terms of sections 15 and 16 of this License; or + + b) Requiring preservation of specified reasonable legal notices or + author attributions in that material or in the Appropriate Legal + Notices displayed by works containing it; or + + c) Prohibiting misrepresentation of the origin of that material, or + requiring that modified versions of such material be marked in + reasonable ways as different from the original version; or + + d) Limiting the use for publicity purposes of names of licensors or + authors of the material; or + + e) Declining to grant rights under trademark law for use of some + trade names, trademarks, or service marks; or + + f) Requiring indemnification of licensors and authors of that + material by anyone who conveys the material (or modified versions of + it) with contractual assumptions of liability to the recipient, for + any liability that these contractual assumptions directly impose on + those licensors and authors. + + All other non-permissive additional terms are considered "further +restrictions" within the meaning of section 10. If the Program as you +received it, or any part of it, contains a notice stating that it is +governed by this License along with a term that is a further +restriction, you may remove that term. If a license document contains +a further restriction but permits relicensing or conveying under this +License, you may add to a covered work material governed by the terms +of that license document, provided that the further restriction does +not survive such relicensing or conveying. + + If you add terms to a covered work in accord with this section, you +must place, in the relevant source files, a statement of the +additional terms that apply to those files, or a notice indicating +where to find the applicable terms. + + Additional terms, permissive or non-permissive, may be stated in the +form of a separately written license, or stated as exceptions; +the above requirements apply either way. + + 8. Termination. + + You may not propagate or modify a covered work except as expressly +provided under this License. Any attempt otherwise to propagate or +modify it is void, and will automatically terminate your rights under +this License (including any patent licenses granted under the third +paragraph of section 11). + + However, if you cease all violation of this License, then your +license from a particular copyright holder is reinstated (a) +provisionally, unless and until the copyright holder explicitly and +finally terminates your license, and (b) permanently, if the copyright +holder fails to notify you of the violation by some reasonable means +prior to 60 days after the cessation. + + Moreover, your license from a particular copyright holder is +reinstated permanently if the copyright holder notifies you of the +violation by some reasonable means, this is the first time you have +received notice of violation of this License (for any work) from that +copyright holder, and you cure the violation prior to 30 days after +your receipt of the notice. + + Termination of your rights under this section does not terminate the +licenses of parties who have received copies or rights from you under +this License. If your rights have been terminated and not permanently +reinstated, you do not qualify to receive new licenses for the same +material under section 10. + + 9. Acceptance Not Required for Having Copies. + + You are not required to accept this License in order to receive or +run a copy of the Program. Ancillary propagation of a covered work +occurring solely as a consequence of using peer-to-peer transmission +to receive a copy likewise does not require acceptance. However, +nothing other than this License grants you permission to propagate or +modify any covered work. These actions infringe copyright if you do +not accept this License. Therefore, by modifying or propagating a +covered work, you indicate your acceptance of this License to do so. + + 10. Automatic Licensing of Downstream Recipients. + + Each time you convey a covered work, the recipient automatically +receives a license from the original licensors, to run, modify and +propagate that work, subject to this License. You are not responsible +for enforcing compliance by third parties with this License. + + An "entity transaction" is a transaction transferring control of an +organization, or substantially all assets of one, or subdividing an +organization, or merging organizations. If propagation of a covered +work results from an entity transaction, each party to that +transaction who receives a copy of the work also receives whatever +licenses to the work the party's predecessor in interest had or could +give under the previous paragraph, plus a right to possession of the +Corresponding Source of the work from the predecessor in interest, if +the predecessor has it or can get it with reasonable efforts. + + You may not impose any further restrictions on the exercise of the +rights granted or affirmed under this License. For example, you may +not impose a license fee, royalty, or other charge for exercise of +rights granted under this License, and you may not initiate litigation +(including a cross-claim or counterclaim in a lawsuit) alleging that +any patent claim is infringed by making, using, selling, offering for +sale, or importing the Program or any portion of it. + + 11. Patents. + + A "contributor" is a copyright holder who authorizes use under this +License of the Program or a work on which the Program is based. The +work thus licensed is called the contributor's "contributor version". + + A contributor's "essential patent claims" are all patent claims +owned or controlled by the contributor, whether already acquired or +hereafter acquired, that would be infringed by some manner, permitted +by this License, of making, using, or selling its contributor version, +but do not include claims that would be infringed only as a +consequence of further modification of the contributor version. For +purposes of this definition, "control" includes the right to grant +patent sublicenses in a manner consistent with the requirements of +this License. + + Each contributor grants you a non-exclusive, worldwide, royalty-free +patent license under the contributor's essential patent claims, to +make, use, sell, offer for sale, import and otherwise run, modify and +propagate the contents of its contributor version. + + In the following three paragraphs, a "patent license" is any express +agreement or commitment, however denominated, not to enforce a patent +(such as an express permission to practice a patent or covenant not to +sue for patent infringement). To "grant" such a patent license to a +party means to make such an agreement or commitment not to enforce a +patent against the party. + + If you convey a covered work, knowingly relying on a patent license, +and the Corresponding Source of the work is not available for anyone +to copy, free of charge and under the terms of this License, through a +publicly available network server or other readily accessible means, +then you must either (1) cause the Corresponding Source to be so +available, or (2) arrange to deprive yourself of the benefit of the +patent license for this particular work, or (3) arrange, in a manner +consistent with the requirements of this License, to extend the patent +license to downstream recipients. "Knowingly relying" means you have +actual knowledge that, but for the patent license, your conveying the +covered work in a country, or your recipient's use of the covered work +in a country, would infringe one or more identifiable patents in that +country that you have reason to believe are valid. + + If, pursuant to or in connection with a single transaction or +arrangement, you convey, or propagate by procuring conveyance of, a +covered work, and grant a patent license to some of the parties +receiving the covered work authorizing them to use, propagate, modify +or convey a specific copy of the covered work, then the patent license +you grant is automatically extended to all recipients of the covered +work and works based on it. + + A patent license is "discriminatory" if it does not include within +the scope of its coverage, prohibits the exercise of, or is +conditioned on the non-exercise of one or more of the rights that are +specifically granted under this License. You may not convey a covered +work if you are a party to an arrangement with a third party that is +in the business of distributing software, under which you make payment +to the third party based on the extent of your activity of conveying +the work, and under which the third party grants, to any of the +parties who would receive the covered work from you, a discriminatory +patent license (a) in connection with copies of the covered work +conveyed by you (or copies made from those copies), or (b) primarily +for and in connection with specific products or compilations that +contain the covered work, unless you entered into that arrangement, +or that patent license was granted, prior to 28 March 2007. + + Nothing in this License shall be construed as excluding or limiting +any implied license or other defenses to infringement that may +otherwise be available to you under applicable patent law. + + 12. No Surrender of Others' Freedom. + + If conditions are imposed on you (whether by court order, agreement or +otherwise) that contradict the conditions of this License, they do not +excuse you from the conditions of this License. If you cannot convey a +covered work so as to satisfy simultaneously your obligations under this +License and any other pertinent obligations, then as a consequence you may +not convey it at all. For example, if you agree to terms that obligate you +to collect a royalty for further conveying from those to whom you convey +the Program, the only way you could satisfy both those terms and this +License would be to refrain entirely from conveying the Program. + + 13. Use with the GNU Affero General Public License. + + Notwithstanding any other provision of this License, you have +permission to link or combine any covered work with a work licensed +under version 3 of the GNU Affero General Public License into a single +combined work, and to convey the resulting work. The terms of this +License will continue to apply to the part which is the covered work, +but the special requirements of the GNU Affero General Public License, +section 13, concerning interaction through a network will apply to the +combination as such. + + 14. Revised Versions of this License. + + The Free Software Foundation may publish revised and/or new versions of +the GNU General Public License from time to time. Such new versions will +be similar in spirit to the present version, but may differ in detail to +address new problems or concerns. + + Each version is given a distinguishing version number. If the +Program specifies that a certain numbered version of the GNU General +Public License "or any later version" applies to it, you have the +option of following the terms and conditions either of that numbered +version or of any later version published by the Free Software +Foundation. If the Program does not specify a version number of the +GNU General Public License, you may choose any version ever published +by the Free Software Foundation. + + If the Program specifies that a proxy can decide which future +versions of the GNU General Public License can be used, that proxy's +public statement of acceptance of a version permanently authorizes you +to choose that version for the Program. + + Later license versions may give you additional or different +permissions. However, no additional obligations are imposed on any +author or copyright holder as a result of your choosing to follow a +later version. + + 15. Disclaimer of Warranty. + + THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY +APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT +HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY +OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, +THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM +IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF +ALL NECESSARY SERVICING, REPAIR OR CORRECTION. + + 16. Limitation of Liability. + + IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING +WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS +THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY +GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE +USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF +DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD +PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), +EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF +SUCH DAMAGES. + + 17. Interpretation of Sections 15 and 16. + + If the disclaimer of warranty and limitation of liability provided +above cannot be given local legal effect according to their terms, +reviewing courts shall apply local law that most closely approximates +an absolute waiver of all civil liability in connection with the +Program, unless a warranty or assumption of liability accompanies a +copy of the Program in return for a fee. + + END OF TERMS AND CONDITIONS + + How to Apply These Terms to Your New Programs + + If you develop a new program, and you want it to be of the greatest +possible use to the public, the best way to achieve this is to make it +free software which everyone can redistribute and change under these terms. + + To do so, attach the following notices to the program. It is safest +to attach them to the start of each source file to most effectively +state the exclusion of warranty; and each file should have at least +the "copyright" line and a pointer to where the full notice is found. + + + Copyright (C) + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see . + +Also add information on how to contact you by electronic and paper mail. + + If the program does terminal interaction, make it output a short +notice like this when it starts in an interactive mode: + + Copyright (C) + This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'. + This is free software, and you are welcome to redistribute it + under certain conditions; type `show c' for details. + +The hypothetical commands `show w' and `show c' should show the appropriate +parts of the General Public License. Of course, your program's commands +might be different; for a GUI interface, you would use an "about box". + + You should also get your employer (if you work as a programmer) or school, +if any, to sign a "copyright disclaimer" for the program, if necessary. +For more information on this, and how to apply and follow the GNU GPL, see +. + + The GNU General Public License does not permit incorporating your program +into proprietary programs. If your program is a subroutine library, you +may consider it more useful to permit linking proprietary applications with +the library. If this is what you want to do, use the GNU Lesser General +Public License instead of this License. But first, please read +. diff --git a/deforum-stable-diffusion/configs/v1-inference.yaml b/deforum-stable-diffusion/configs/v1-inference.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d4effe569e897369918625f9d8be5603a0e6a0d6 --- /dev/null +++ b/deforum-stable-diffusion/configs/v1-inference.yaml @@ -0,0 +1,70 @@ +model: + base_learning_rate: 1.0e-04 + target: ldm.models.diffusion.ddpm.LatentDiffusion + params: + linear_start: 0.00085 + linear_end: 0.0120 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: "jpg" + cond_stage_key: "txt" + image_size: 64 + channels: 4 + cond_stage_trainable: false # Note: different from the one we trained before + conditioning_key: crossattn + monitor: val/loss_simple_ema + scale_factor: 0.18215 + use_ema: False + + scheduler_config: # 10000 warmup steps + target: ldm.lr_scheduler.LambdaLinearScheduler + params: + warm_up_steps: [ 10000 ] + cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases + f_start: [ 1.e-6 ] + f_max: [ 1. ] + f_min: [ 1. ] + + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + image_size: 32 # unused + in_channels: 4 + out_channels: 4 + model_channels: 320 + attention_resolutions: [ 4, 2, 1 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4, 4 ] + num_heads: 8 + use_spatial_transformer: True + transformer_depth: 1 + context_dim: 768 + use_checkpoint: True + legacy: False + + first_stage_config: + target: ldm.models.autoencoder.AutoencoderKL + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + cond_stage_config: + target: ldm.modules.encoders.modules.FrozenCLIPEmbedder diff --git a/deforum-stable-diffusion/configs/v2-inference-v.yaml b/deforum-stable-diffusion/configs/v2-inference-v.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8ec8dfbfefe94ae8522c93017668fea78d580acf --- /dev/null +++ b/deforum-stable-diffusion/configs/v2-inference-v.yaml @@ -0,0 +1,68 @@ +model: + base_learning_rate: 1.0e-4 + target: ldm.models.diffusion.ddpm.LatentDiffusion + params: + parameterization: "v" + linear_start: 0.00085 + linear_end: 0.0120 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: "jpg" + cond_stage_key: "txt" + image_size: 64 + channels: 4 + cond_stage_trainable: false + conditioning_key: crossattn + monitor: val/loss_simple_ema + scale_factor: 0.18215 + use_ema: False # we set this to false because this is an inference only config + + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + use_checkpoint: True + use_fp16: True + image_size: 32 # unused + in_channels: 4 + out_channels: 4 + model_channels: 320 + attention_resolutions: [ 4, 2, 1 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4, 4 ] + num_head_channels: 64 # need to fix for flash-attn + use_spatial_transformer: True + use_linear_in_transformer: True + transformer_depth: 1 + context_dim: 1024 + legacy: False + + first_stage_config: + target: ldm.models.autoencoder.AutoencoderKL + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + #attn_type: "vanilla-xformers" + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + cond_stage_config: + target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder + params: + freeze: True + layer: "penultimate" diff --git a/deforum-stable-diffusion/configs/v2-inference.yaml b/deforum-stable-diffusion/configs/v2-inference.yaml new file mode 100644 index 0000000000000000000000000000000000000000..152c4f3c2b36c3b246a9cb10eb8166134b0d2e1c --- /dev/null +++ b/deforum-stable-diffusion/configs/v2-inference.yaml @@ -0,0 +1,67 @@ +model: + base_learning_rate: 1.0e-4 + target: ldm.models.diffusion.ddpm.LatentDiffusion + params: + linear_start: 0.00085 + linear_end: 0.0120 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: "jpg" + cond_stage_key: "txt" + image_size: 64 + channels: 4 + cond_stage_trainable: false + conditioning_key: crossattn + monitor: val/loss_simple_ema + scale_factor: 0.18215 + use_ema: False # we set this to false because this is an inference only config + + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + use_checkpoint: True + use_fp16: True + image_size: 32 # unused + in_channels: 4 + out_channels: 4 + model_channels: 320 + attention_resolutions: [ 4, 2, 1 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4, 4 ] + num_head_channels: 64 # need to fix for flash-attn + use_spatial_transformer: True + use_linear_in_transformer: True + transformer_depth: 1 + context_dim: 1024 + legacy: False + + first_stage_config: + target: ldm.models.autoencoder.AutoencoderKL + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + #attn_type: "vanilla-xformers" + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + cond_stage_config: + target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder + params: + freeze: True + layer: "penultimate" diff --git a/deforum-stable-diffusion/configs/v2-inpainting-inference.yaml b/deforum-stable-diffusion/configs/v2-inpainting-inference.yaml new file mode 100644 index 0000000000000000000000000000000000000000..32a9471d71b828c51bcbbabfe34c5f6c8282c803 --- /dev/null +++ b/deforum-stable-diffusion/configs/v2-inpainting-inference.yaml @@ -0,0 +1,158 @@ +model: + base_learning_rate: 5.0e-05 + target: ldm.models.diffusion.ddpm.LatentInpaintDiffusion + params: + linear_start: 0.00085 + linear_end: 0.0120 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: "jpg" + cond_stage_key: "txt" + image_size: 64 + channels: 4 + cond_stage_trainable: false + conditioning_key: hybrid + scale_factor: 0.18215 + monitor: val/loss_simple_ema + finetune_keys: null + use_ema: False + + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + use_checkpoint: True + image_size: 32 # unused + in_channels: 9 + out_channels: 4 + model_channels: 320 + attention_resolutions: [ 4, 2, 1 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4, 4 ] + num_head_channels: 64 # need to fix for flash-attn + use_spatial_transformer: True + use_linear_in_transformer: True + transformer_depth: 1 + context_dim: 1024 + legacy: False + + first_stage_config: + target: ldm.models.autoencoder.AutoencoderKL + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + #attn_type: "vanilla-xformers" + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [ ] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + cond_stage_config: + target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder + params: + freeze: True + layer: "penultimate" + + +data: + target: ldm.data.laion.WebDataModuleFromConfig + params: + tar_base: null # for concat as in LAION-A + p_unsafe_threshold: 0.1 + filter_word_list: "data/filters.yaml" + max_pwatermark: 0.45 + batch_size: 8 + num_workers: 6 + multinode: True + min_size: 512 + train: + shards: + - "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-0/{00000..18699}.tar -" + - "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-1/{00000..18699}.tar -" + - "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-2/{00000..18699}.tar -" + - "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-3/{00000..18699}.tar -" + - "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-4/{00000..18699}.tar -" #{00000-94333}.tar" + shuffle: 10000 + image_key: jpg + image_transforms: + - target: torchvision.transforms.Resize + params: + size: 512 + interpolation: 3 + - target: torchvision.transforms.RandomCrop + params: + size: 512 + postprocess: + target: ldm.data.laion.AddMask + params: + mode: "512train-large" + p_drop: 0.25 + # NOTE use enough shards to avoid empty validation loops in workers + validation: + shards: + - "pipe:aws s3 cp s3://deep-floyd-s3/datasets/laion_cleaned-part5/{93001..94333}.tar - " + shuffle: 0 + image_key: jpg + image_transforms: + - target: torchvision.transforms.Resize + params: + size: 512 + interpolation: 3 + - target: torchvision.transforms.CenterCrop + params: + size: 512 + postprocess: + target: ldm.data.laion.AddMask + params: + mode: "512train-large" + p_drop: 0.25 + +lightning: + find_unused_parameters: True + modelcheckpoint: + params: + every_n_train_steps: 5000 + + callbacks: + metrics_over_trainsteps_checkpoint: + params: + every_n_train_steps: 10000 + + image_logger: + target: main.ImageLogger + params: + enable_autocast: False + disabled: False + batch_frequency: 1000 + max_images: 4 + increase_log_steps: False + log_first_step: False + log_images_kwargs: + use_ema_scope: False + inpaint: False + plot_progressive_rows: False + plot_diffusion_rows: False + N: 4 + unconditional_guidance_scale: 5.0 + unconditional_guidance_label: [""] + ddim_steps: 50 # todo check these out for depth2img, + ddim_eta: 0.0 # todo check these out for depth2img, + + trainer: + benchmark: True + val_check_interval: 5000000 + num_sanity_val_steps: 0 + accumulate_grad_batches: 1 diff --git a/deforum-stable-diffusion/configs/v2-midas-inference.yaml b/deforum-stable-diffusion/configs/v2-midas-inference.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f20c30f618b81091e31c2c4cf15325fa38638af4 --- /dev/null +++ b/deforum-stable-diffusion/configs/v2-midas-inference.yaml @@ -0,0 +1,74 @@ +model: + base_learning_rate: 5.0e-07 + target: ldm.models.diffusion.ddpm.LatentDepth2ImageDiffusion + params: + linear_start: 0.00085 + linear_end: 0.0120 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: "jpg" + cond_stage_key: "txt" + image_size: 64 + channels: 4 + cond_stage_trainable: false + conditioning_key: hybrid + scale_factor: 0.18215 + monitor: val/loss_simple_ema + finetune_keys: null + use_ema: False + + depth_stage_config: + target: ldm.modules.midas.api.MiDaSInference + params: + model_type: "dpt_hybrid" + + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + use_checkpoint: True + image_size: 32 # unused + in_channels: 5 + out_channels: 4 + model_channels: 320 + attention_resolutions: [ 4, 2, 1 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4, 4 ] + num_head_channels: 64 # need to fix for flash-attn + use_spatial_transformer: True + use_linear_in_transformer: True + transformer_depth: 1 + context_dim: 1024 + legacy: False + + first_stage_config: + target: ldm.models.autoencoder.AutoencoderKL + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + #attn_type: "vanilla-xformers" + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [ ] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + cond_stage_config: + target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder + params: + freeze: True + layer: "penultimate" + + diff --git a/deforum-stable-diffusion/configs/x4-upscaling.yaml b/deforum-stable-diffusion/configs/x4-upscaling.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2db0964af699f86d1891c761710a2d53f59b842c --- /dev/null +++ b/deforum-stable-diffusion/configs/x4-upscaling.yaml @@ -0,0 +1,76 @@ +model: + base_learning_rate: 1.0e-04 + target: ldm.models.diffusion.ddpm.LatentUpscaleDiffusion + params: + parameterization: "v" + low_scale_key: "lr" + linear_start: 0.0001 + linear_end: 0.02 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: "jpg" + cond_stage_key: "txt" + image_size: 128 + channels: 4 + cond_stage_trainable: false + conditioning_key: "hybrid-adm" + monitor: val/loss_simple_ema + scale_factor: 0.08333 + use_ema: False + + low_scale_config: + target: ldm.modules.diffusionmodules.upscaling.ImageConcatWithNoiseAugmentation + params: + noise_schedule_config: # image space + linear_start: 0.0001 + linear_end: 0.02 + max_noise_level: 350 + + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + use_checkpoint: True + num_classes: 1000 # timesteps for noise conditioning (here constant, just need one) + image_size: 128 + in_channels: 7 + out_channels: 4 + model_channels: 256 + attention_resolutions: [ 2,4,8] + num_res_blocks: 2 + channel_mult: [ 1, 2, 2, 4] + disable_self_attentions: [True, True, True, False] + disable_middle_self_attn: False + num_heads: 8 + use_spatial_transformer: True + transformer_depth: 1 + context_dim: 1024 + legacy: False + use_linear_in_transformer: True + + first_stage_config: + target: ldm.models.autoencoder.AutoencoderKL + params: + embed_dim: 4 + ddconfig: + # attn_type: "vanilla-xformers" this model needs efficient attention to be feasible on HR data, also the decoder seems to break in half precision (UNet is fine though) + double_z: True + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: [ 1,2,4 ] # num_down = len(ch_mult)-1 + num_res_blocks: 2 + attn_resolutions: [ ] + dropout: 0.0 + + lossconfig: + target: torch.nn.Identity + + cond_stage_config: + target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder + params: + freeze: True + layer: "penultimate" + diff --git a/deforum-stable-diffusion/helpers/__init__.py b/deforum-stable-diffusion/helpers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f8f0b7311cbfe5594df31f2c138d0558923c033e --- /dev/null +++ b/deforum-stable-diffusion/helpers/__init__.py @@ -0,0 +1,9 @@ +""" +from .save_images import save_samples, get_output_folder +from .k_samplers import sampler_fn, make_inject_timing_fn +from .depth import DepthModel +from .prompt import sanitize +from .animation import construct_RotationMatrixHomogenous, getRotationMatrixManual, getPoints_for_PerspectiveTranformEstimation, warpMatrix, anim_frame_warp +from .generate import add_noise, load_img, load_mask_latent, prepare_mask +from .load_images import load_img, load_mask_latent, prepare_mask, prepare_overlay_mask +""" \ No newline at end of file diff --git a/deforum-stable-diffusion/helpers/__pycache__/__init__.cpython-38.pyc b/deforum-stable-diffusion/helpers/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0a50aeb9d45a80e177cf5bdeefea21916713651e Binary files /dev/null and b/deforum-stable-diffusion/helpers/__pycache__/__init__.cpython-38.pyc differ diff --git a/deforum-stable-diffusion/helpers/__pycache__/__init__.cpython-39.pyc b/deforum-stable-diffusion/helpers/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c7dccf90b4bf1f5f8c6f6bdef748536152c79a34 Binary files /dev/null and b/deforum-stable-diffusion/helpers/__pycache__/__init__.cpython-39.pyc differ diff --git a/deforum-stable-diffusion/helpers/__pycache__/aesthetics.cpython-38.pyc b/deforum-stable-diffusion/helpers/__pycache__/aesthetics.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3d4a6472b03396ae145d54d5c5a16328d2028f7d Binary files /dev/null and b/deforum-stable-diffusion/helpers/__pycache__/aesthetics.cpython-38.pyc differ diff --git a/deforum-stable-diffusion/helpers/__pycache__/animation.cpython-38.pyc b/deforum-stable-diffusion/helpers/__pycache__/animation.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a7a3abb0a07faebf16fe77253a36d02a98813995 Binary files /dev/null and b/deforum-stable-diffusion/helpers/__pycache__/animation.cpython-38.pyc differ diff --git a/deforum-stable-diffusion/helpers/__pycache__/callback.cpython-38.pyc b/deforum-stable-diffusion/helpers/__pycache__/callback.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..15bede88c2c9c153b6eed283717a9dbf963d2bd9 Binary files /dev/null and b/deforum-stable-diffusion/helpers/__pycache__/callback.cpython-38.pyc differ diff --git a/deforum-stable-diffusion/helpers/__pycache__/colors.cpython-38.pyc b/deforum-stable-diffusion/helpers/__pycache__/colors.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8ba4a7d8b456278cdb00c9e0acdf807d34d8d14d Binary files /dev/null and b/deforum-stable-diffusion/helpers/__pycache__/colors.cpython-38.pyc differ diff --git a/deforum-stable-diffusion/helpers/__pycache__/conditioning.cpython-38.pyc b/deforum-stable-diffusion/helpers/__pycache__/conditioning.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bd46f87d5d47b5ee78b13345ef7c864f78788364 Binary files /dev/null and b/deforum-stable-diffusion/helpers/__pycache__/conditioning.cpython-38.pyc differ diff --git a/deforum-stable-diffusion/helpers/__pycache__/depth.cpython-38.pyc b/deforum-stable-diffusion/helpers/__pycache__/depth.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1cda26f9eed6d01b8fecf8cee61d58e00e34950e Binary files /dev/null and b/deforum-stable-diffusion/helpers/__pycache__/depth.cpython-38.pyc differ diff --git a/deforum-stable-diffusion/helpers/__pycache__/generate.cpython-38.pyc b/deforum-stable-diffusion/helpers/__pycache__/generate.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9b0ad6d373552e6d7c31f81859fdee364a18b2d3 Binary files /dev/null and b/deforum-stable-diffusion/helpers/__pycache__/generate.cpython-38.pyc differ diff --git a/deforum-stable-diffusion/helpers/__pycache__/generate.cpython-39.pyc b/deforum-stable-diffusion/helpers/__pycache__/generate.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..52b3d05799a1129bb16a2bc31dfcd05e8c772454 Binary files /dev/null and b/deforum-stable-diffusion/helpers/__pycache__/generate.cpython-39.pyc differ diff --git a/deforum-stable-diffusion/helpers/__pycache__/k_samplers.cpython-38.pyc b/deforum-stable-diffusion/helpers/__pycache__/k_samplers.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..09f3269ae4b5a85ee7b83d633a621758cb6a900c Binary files /dev/null and b/deforum-stable-diffusion/helpers/__pycache__/k_samplers.cpython-38.pyc differ diff --git a/deforum-stable-diffusion/helpers/__pycache__/load_images.cpython-38.pyc b/deforum-stable-diffusion/helpers/__pycache__/load_images.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e95616f07f3a3e44738b198ba2cb1a74fe92e72d Binary files /dev/null and b/deforum-stable-diffusion/helpers/__pycache__/load_images.cpython-38.pyc differ diff --git a/deforum-stable-diffusion/helpers/__pycache__/model_load.cpython-38.pyc b/deforum-stable-diffusion/helpers/__pycache__/model_load.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5dff8c9cb88c5a8e8cd2e804d59b9be9e34570fc Binary files /dev/null and b/deforum-stable-diffusion/helpers/__pycache__/model_load.cpython-38.pyc differ diff --git a/deforum-stable-diffusion/helpers/__pycache__/model_wrap.cpython-38.pyc b/deforum-stable-diffusion/helpers/__pycache__/model_wrap.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ed13f588573c40cba8bb3815b939773a9e9da07c Binary files /dev/null and b/deforum-stable-diffusion/helpers/__pycache__/model_wrap.cpython-38.pyc differ diff --git a/deforum-stable-diffusion/helpers/__pycache__/prompt.cpython-38.pyc b/deforum-stable-diffusion/helpers/__pycache__/prompt.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0c7d54dcd967c0bcb4833964f541e8892442d48b Binary files /dev/null and b/deforum-stable-diffusion/helpers/__pycache__/prompt.cpython-38.pyc differ diff --git a/deforum-stable-diffusion/helpers/__pycache__/render.cpython-38.pyc b/deforum-stable-diffusion/helpers/__pycache__/render.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e072d4fb0fcf8d82f40c86187da7207d4f11f514 Binary files /dev/null and b/deforum-stable-diffusion/helpers/__pycache__/render.cpython-38.pyc differ diff --git a/deforum-stable-diffusion/helpers/__pycache__/render.cpython-39.pyc b/deforum-stable-diffusion/helpers/__pycache__/render.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..12e8b2be8f640a82f7869ef92cc18fab2d7edd8a Binary files /dev/null and b/deforum-stable-diffusion/helpers/__pycache__/render.cpython-39.pyc differ diff --git a/deforum-stable-diffusion/helpers/__pycache__/save_images.cpython-38.pyc b/deforum-stable-diffusion/helpers/__pycache__/save_images.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..200ab92455e97ad50d12f918847d0ce400ab1926 Binary files /dev/null and b/deforum-stable-diffusion/helpers/__pycache__/save_images.cpython-38.pyc differ diff --git a/deforum-stable-diffusion/helpers/__pycache__/save_images.cpython-39.pyc b/deforum-stable-diffusion/helpers/__pycache__/save_images.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..96e222c34644246404415966f186838a6732ffbb Binary files /dev/null and b/deforum-stable-diffusion/helpers/__pycache__/save_images.cpython-39.pyc differ diff --git a/deforum-stable-diffusion/helpers/__pycache__/settings.cpython-38.pyc b/deforum-stable-diffusion/helpers/__pycache__/settings.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ff27462b999b942452a6b0ff953e347abb6a2d64 Binary files /dev/null and b/deforum-stable-diffusion/helpers/__pycache__/settings.cpython-38.pyc differ diff --git a/deforum-stable-diffusion/helpers/__pycache__/settings.cpython-39.pyc b/deforum-stable-diffusion/helpers/__pycache__/settings.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ba793744331d6633b54a5dc719cad6b49239a6df Binary files /dev/null and b/deforum-stable-diffusion/helpers/__pycache__/settings.cpython-39.pyc differ diff --git a/deforum-stable-diffusion/helpers/__pycache__/simulacra_fit_linear_model.cpython-38.pyc b/deforum-stable-diffusion/helpers/__pycache__/simulacra_fit_linear_model.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..653911af7fdd8e3d0d16518dc87e9cd11ac90bda Binary files /dev/null and b/deforum-stable-diffusion/helpers/__pycache__/simulacra_fit_linear_model.cpython-38.pyc differ diff --git a/deforum-stable-diffusion/helpers/aesthetics.py b/deforum-stable-diffusion/helpers/aesthetics.py new file mode 100644 index 0000000000000000000000000000000000000000..a9d463b2e94c5d39bd6b2df861a9184cbd61fd5f --- /dev/null +++ b/deforum-stable-diffusion/helpers/aesthetics.py @@ -0,0 +1,48 @@ +import os +import torch +from .simulacra_fit_linear_model import AestheticMeanPredictionLinearModel +import requests + +def wget(url, outputdir): + filename = url.split("/")[-1] + + ckpt_request = requests.get(url) + request_status = ckpt_request.status_code + + # inform user of errors + if request_status == 403: + raise ConnectionRefusedError("You have not accepted the license for this model.") + elif request_status == 404: + raise ConnectionError("Could not make contact with server") + elif request_status != 200: + raise ConnectionError(f"Some other error has ocurred - response code: {request_status}") + + # write to model path + with open(os.path.join(outputdir, filename), 'wb') as model_file: + model_file.write(ckpt_request.content) + + +def load_aesthetics_model(args,root): + + clip_size = { + "ViT-B/32": 512, + "ViT-B/16": 512, + "ViT-L/14": 768, + "ViT-L/14@336px": 768, + } + + model_name = { + "ViT-B/32": "sac_public_2022_06_29_vit_b_32_linear.pth", + "ViT-B/16": "sac_public_2022_06_29_vit_b_16_linear.pth", + "ViT-L/14": "sac_public_2022_06_29_vit_l_14_linear.pth", + } + + if not os.path.exists(os.path.join(root.models_path,model_name[args.clip_name])): + print("Downloading aesthetics model...") + os.makedirs(root.models_path, exist_ok=True) + wget("https://github.com/crowsonkb/simulacra-aesthetic-models/raw/master/models/"+model_name[args.clip_name], root.models_path) + + aesthetics_model = AestheticMeanPredictionLinearModel(clip_size[args.clip_name]) + aesthetics_model.load_state_dict(torch.load(os.path.join(root.models_path,model_name[args.clip_name]))) + + return aesthetics_model.to(root.device) diff --git a/deforum-stable-diffusion/helpers/animation.py b/deforum-stable-diffusion/helpers/animation.py new file mode 100644 index 0000000000000000000000000000000000000000..27211d769b57c4a074b4fc4b8168362b4ba68ed9 --- /dev/null +++ b/deforum-stable-diffusion/helpers/animation.py @@ -0,0 +1,338 @@ +import numpy as np +import cv2 +from functools import reduce +import math +import py3d_tools as p3d +import torch +from einops import rearrange +import re +import pathlib +import os +import pandas as pd + +def check_is_number(value): + float_pattern = r'^(?=.)([+-]?([0-9]*)(\.([0-9]+))?)$' + return re.match(float_pattern, value) + +def sample_from_cv2(sample: np.ndarray) -> torch.Tensor: + sample = ((sample.astype(float) / 255.0) * 2) - 1 + sample = sample[None].transpose(0, 3, 1, 2).astype(np.float16) + sample = torch.from_numpy(sample) + return sample + +def sample_to_cv2(sample: torch.Tensor, type=np.uint8) -> np.ndarray: + sample_f32 = rearrange(sample.squeeze().cpu().numpy(), "c h w -> h w c").astype(np.float32) + sample_f32 = ((sample_f32 * 0.5) + 0.5).clip(0, 1) + sample_int8 = (sample_f32 * 255) + return sample_int8.astype(type) + +def construct_RotationMatrixHomogenous(rotation_angles): + assert(type(rotation_angles)==list and len(rotation_angles)==3) + RH = np.eye(4,4) + cv2.Rodrigues(np.array(rotation_angles), RH[0:3, 0:3]) + return RH + +def vid2frames(video_path, frames_path, n=1, overwrite=True): + if not os.path.exists(frames_path) or overwrite: + try: + for f in pathlib.Path(frames_path).glob('*.jpg'): + f.unlink() + except: + pass + assert os.path.exists(video_path), f"Video input {video_path} does not exist" + + vidcap = cv2.VideoCapture(video_path) + success,image = vidcap.read() + count = 0 + t=1 + success = True + while success: + if count % n == 0: + cv2.imwrite(frames_path + os.path.sep + f"{t:05}.jpg" , image) # save frame as JPEG file + t += 1 + success,image = vidcap.read() + count += 1 + print("Converted %d frames" % count) + else: print("Frames already unpacked") + +# https://en.wikipedia.org/wiki/Rotation_matrix +def getRotationMatrixManual(rotation_angles): + + rotation_angles = [np.deg2rad(x) for x in rotation_angles] + + phi = rotation_angles[0] # around x + gamma = rotation_angles[1] # around y + theta = rotation_angles[2] # around z + + # X rotation + Rphi = np.eye(4,4) + sp = np.sin(phi) + cp = np.cos(phi) + Rphi[1,1] = cp + Rphi[2,2] = Rphi[1,1] + Rphi[1,2] = -sp + Rphi[2,1] = sp + + # Y rotation + Rgamma = np.eye(4,4) + sg = np.sin(gamma) + cg = np.cos(gamma) + Rgamma[0,0] = cg + Rgamma[2,2] = Rgamma[0,0] + Rgamma[0,2] = sg + Rgamma[2,0] = -sg + + # Z rotation (in-image-plane) + Rtheta = np.eye(4,4) + st = np.sin(theta) + ct = np.cos(theta) + Rtheta[0,0] = ct + Rtheta[1,1] = Rtheta[0,0] + Rtheta[0,1] = -st + Rtheta[1,0] = st + + R = reduce(lambda x,y : np.matmul(x,y), [Rphi, Rgamma, Rtheta]) + + return R + +def getPoints_for_PerspectiveTranformEstimation(ptsIn, ptsOut, W, H, sidelength): + + ptsIn2D = ptsIn[0,:] + ptsOut2D = ptsOut[0,:] + ptsOut2Dlist = [] + ptsIn2Dlist = [] + + for i in range(0,4): + ptsOut2Dlist.append([ptsOut2D[i,0], ptsOut2D[i,1]]) + ptsIn2Dlist.append([ptsIn2D[i,0], ptsIn2D[i,1]]) + + pin = np.array(ptsIn2Dlist) + [W/2.,H/2.] + pout = (np.array(ptsOut2Dlist) + [1.,1.]) * (0.5*sidelength) + pin = pin.astype(np.float32) + pout = pout.astype(np.float32) + + return pin, pout + + +def warpMatrix(W, H, theta, phi, gamma, scale, fV): + + # M is to be estimated + M = np.eye(4, 4) + + fVhalf = np.deg2rad(fV/2.) + d = np.sqrt(W*W+H*H) + sideLength = scale*d/np.cos(fVhalf) + h = d/(2.0*np.sin(fVhalf)) + n = h-(d/2.0) + f = h+(d/2.0) + + # Translation along Z-axis by -h + T = np.eye(4,4) + T[2,3] = -h + + # Rotation matrices around x,y,z + R = getRotationMatrixManual([phi, gamma, theta]) + + + # Projection Matrix + P = np.eye(4,4) + P[0,0] = 1.0/np.tan(fVhalf) + P[1,1] = P[0,0] + P[2,2] = -(f+n)/(f-n) + P[2,3] = -(2.0*f*n)/(f-n) + P[3,2] = -1.0 + + # pythonic matrix multiplication + F = reduce(lambda x,y : np.matmul(x,y), [P, T, R]) + + # shape should be 1,4,3 for ptsIn and ptsOut since perspectiveTransform() expects data in this way. + # In C++, this can be achieved by Mat ptsIn(1,4,CV_64FC3); + ptsIn = np.array([[ + [-W/2., H/2., 0.],[ W/2., H/2., 0.],[ W/2.,-H/2., 0.],[-W/2.,-H/2., 0.] + ]]) + ptsOut = np.array(np.zeros((ptsIn.shape), dtype=ptsIn.dtype)) + ptsOut = cv2.perspectiveTransform(ptsIn, F) + + ptsInPt2f, ptsOutPt2f = getPoints_for_PerspectiveTranformEstimation(ptsIn, ptsOut, W, H, sideLength) + + # check float32 otherwise OpenCV throws an error + assert(ptsInPt2f.dtype == np.float32) + assert(ptsOutPt2f.dtype == np.float32) + M33 = cv2.getPerspectiveTransform(ptsInPt2f,ptsOutPt2f) + + return M33, sideLength + +def anim_frame_warp(prev, args, anim_args, keys, frame_idx, depth_model=None, depth=None, device='cuda'): + if isinstance(prev, np.ndarray): + prev_img_cv2 = prev + else: + prev_img_cv2 = sample_to_cv2(prev) + + if anim_args.use_depth_warping: + if depth is None and depth_model is not None: + depth = depth_model.predict(prev_img_cv2, anim_args) + else: + depth = None + + if anim_args.animation_mode == '2D': + prev_img = anim_frame_warp_2d(prev_img_cv2, args, anim_args, keys, frame_idx) + else: # '3D' + prev_img = anim_frame_warp_3d(device, prev_img_cv2, depth, anim_args, keys, frame_idx) + + return prev_img, depth + +def anim_frame_warp_2d(prev_img_cv2, args, anim_args, keys, frame_idx): + angle = keys.angle_series[frame_idx] + zoom = keys.zoom_series[frame_idx] + translation_x = keys.translation_x_series[frame_idx] + translation_y = keys.translation_y_series[frame_idx] + + center = (args.W // 2, args.H // 2) + trans_mat = np.float32([[1, 0, translation_x], [0, 1, translation_y]]) + rot_mat = cv2.getRotationMatrix2D(center, angle, zoom) + trans_mat = np.vstack([trans_mat, [0,0,1]]) + rot_mat = np.vstack([rot_mat, [0,0,1]]) + if anim_args.flip_2d_perspective: + perspective_flip_theta = keys.perspective_flip_theta_series[frame_idx] + perspective_flip_phi = keys.perspective_flip_phi_series[frame_idx] + perspective_flip_gamma = keys.perspective_flip_gamma_series[frame_idx] + perspective_flip_fv = keys.perspective_flip_fv_series[frame_idx] + M,sl = warpMatrix(args.W, args.H, perspective_flip_theta, perspective_flip_phi, perspective_flip_gamma, 1., perspective_flip_fv); + post_trans_mat = np.float32([[1, 0, (args.W-sl)/2], [0, 1, (args.H-sl)/2]]) + post_trans_mat = np.vstack([post_trans_mat, [0,0,1]]) + bM = np.matmul(M, post_trans_mat) + xform = np.matmul(bM, rot_mat, trans_mat) + else: + xform = np.matmul(rot_mat, trans_mat) + + return cv2.warpPerspective( + prev_img_cv2, + xform, + (prev_img_cv2.shape[1], prev_img_cv2.shape[0]), + borderMode=cv2.BORDER_WRAP if anim_args.border == 'wrap' else cv2.BORDER_REPLICATE + ) + +def anim_frame_warp_3d(device, prev_img_cv2, depth, anim_args, keys, frame_idx): + TRANSLATION_SCALE = 1.0/200.0 # matches Disco + translate_xyz = [ + -keys.translation_x_series[frame_idx] * TRANSLATION_SCALE, + keys.translation_y_series[frame_idx] * TRANSLATION_SCALE, + -keys.translation_z_series[frame_idx] * TRANSLATION_SCALE + ] + rotate_xyz = [ + math.radians(keys.rotation_3d_x_series[frame_idx]), + math.radians(keys.rotation_3d_y_series[frame_idx]), + math.radians(keys.rotation_3d_z_series[frame_idx]) + ] + rot_mat = p3d.euler_angles_to_matrix(torch.tensor(rotate_xyz, device=device), "XYZ").unsqueeze(0) + result = transform_image_3d(device, prev_img_cv2, depth, rot_mat, translate_xyz, anim_args) + torch.cuda.empty_cache() + return result + +def transform_image_3d(device, prev_img_cv2, depth_tensor, rot_mat, translate, anim_args): + # adapted and optimized version of transform_image_3d from Disco Diffusion https://github.com/alembics/disco-diffusion + w, h = prev_img_cv2.shape[1], prev_img_cv2.shape[0] + + aspect_ratio = float(w)/float(h) + near, far, fov_deg = anim_args.near_plane, anim_args.far_plane, anim_args.fov + persp_cam_old = p3d.FoVPerspectiveCameras(near, far, aspect_ratio, fov=fov_deg, degrees=True, device=device) + persp_cam_new = p3d.FoVPerspectiveCameras(near, far, aspect_ratio, fov=fov_deg, degrees=True, R=rot_mat, T=torch.tensor([translate]), device=device) + + # range of [-1,1] is important to torch grid_sample's padding handling + y,x = torch.meshgrid(torch.linspace(-1.,1.,h,dtype=torch.float32,device=device),torch.linspace(-1.,1.,w,dtype=torch.float32,device=device)) + if depth_tensor is None: + z = torch.ones_like(x) + else: + z = torch.as_tensor(depth_tensor, dtype=torch.float32, device=device) + xyz_old_world = torch.stack((x.flatten(), y.flatten(), z.flatten()), dim=1) + + xyz_old_cam_xy = persp_cam_old.get_full_projection_transform().transform_points(xyz_old_world)[:,0:2] + xyz_new_cam_xy = persp_cam_new.get_full_projection_transform().transform_points(xyz_old_world)[:,0:2] + + offset_xy = xyz_new_cam_xy - xyz_old_cam_xy + # affine_grid theta param expects a batch of 2D mats. Each is 2x3 to do rotation+translation. + identity_2d_batch = torch.tensor([[1.,0.,0.],[0.,1.,0.]], device=device).unsqueeze(0) + # coords_2d will have shape (N,H,W,2).. which is also what grid_sample needs. + coords_2d = torch.nn.functional.affine_grid(identity_2d_batch, [1,1,h,w], align_corners=False) + offset_coords_2d = coords_2d - torch.reshape(offset_xy, (h,w,2)).unsqueeze(0) + + image_tensor = rearrange(torch.from_numpy(prev_img_cv2.astype(np.float32)), 'h w c -> c h w').to(device) + new_image = torch.nn.functional.grid_sample( + image_tensor.add(1/512 - 0.0001).unsqueeze(0), + offset_coords_2d, + mode=anim_args.sampling_mode, + padding_mode=anim_args.padding_mode, + align_corners=False + ) + + # convert back to cv2 style numpy array + result = rearrange( + new_image.squeeze().clamp(0,255), + 'c h w -> h w c' + ).cpu().numpy().astype(prev_img_cv2.dtype) + return result + +class DeformAnimKeys(): + def __init__(self, anim_args): + self.angle_series = get_inbetweens(parse_key_frames(anim_args.angle), anim_args.max_frames) + self.zoom_series = get_inbetweens(parse_key_frames(anim_args.zoom), anim_args.max_frames) + self.translation_x_series = get_inbetweens(parse_key_frames(anim_args.translation_x), anim_args.max_frames) + self.translation_y_series = get_inbetweens(parse_key_frames(anim_args.translation_y), anim_args.max_frames) + self.translation_z_series = get_inbetweens(parse_key_frames(anim_args.translation_z), anim_args.max_frames) + self.rotation_3d_x_series = get_inbetweens(parse_key_frames(anim_args.rotation_3d_x), anim_args.max_frames) + self.rotation_3d_y_series = get_inbetweens(parse_key_frames(anim_args.rotation_3d_y), anim_args.max_frames) + self.rotation_3d_z_series = get_inbetweens(parse_key_frames(anim_args.rotation_3d_z), anim_args.max_frames) + self.perspective_flip_theta_series = get_inbetweens(parse_key_frames(anim_args.perspective_flip_theta), anim_args.max_frames) + self.perspective_flip_phi_series = get_inbetweens(parse_key_frames(anim_args.perspective_flip_phi), anim_args.max_frames) + self.perspective_flip_gamma_series = get_inbetweens(parse_key_frames(anim_args.perspective_flip_gamma), anim_args.max_frames) + self.perspective_flip_fv_series = get_inbetweens(parse_key_frames(anim_args.perspective_flip_fv), anim_args.max_frames) + self.noise_schedule_series = get_inbetweens(parse_key_frames(anim_args.noise_schedule), anim_args.max_frames) + self.strength_schedule_series = get_inbetweens(parse_key_frames(anim_args.strength_schedule), anim_args.max_frames) + self.contrast_schedule_series = get_inbetweens(parse_key_frames(anim_args.contrast_schedule), anim_args.max_frames) + +def get_inbetweens(key_frames, max_frames, integer=False, interp_method='Linear'): + import numexpr + key_frame_series = pd.Series([np.nan for a in range(max_frames)]) + + for i in range(0, max_frames): + if i in key_frames: + value = key_frames[i] + value_is_number = check_is_number(value) + # if it's only a number, leave the rest for the default interpolation + if value_is_number: + t = i + key_frame_series[i] = value + if not value_is_number: + t = i + key_frame_series[i] = numexpr.evaluate(value) + key_frame_series = key_frame_series.astype(float) + + if interp_method == 'Cubic' and len(key_frames.items()) <= 3: + interp_method = 'Quadratic' + if interp_method == 'Quadratic' and len(key_frames.items()) <= 2: + interp_method = 'Linear' + + key_frame_series[0] = key_frame_series[key_frame_series.first_valid_index()] + key_frame_series[max_frames-1] = key_frame_series[key_frame_series.last_valid_index()] + key_frame_series = key_frame_series.interpolate(method=interp_method.lower(), limit_direction='both') + if integer: + return key_frame_series.astype(int) + return key_frame_series + +def parse_key_frames(string, prompt_parser=None): + # because math functions (i.e. sin(t)) can utilize brackets + # it extracts the value in form of some stuff + # which has previously been enclosed with brackets and + # with a comma or end of line existing after the closing one + pattern = r'((?P[0-9]+):[\s]*\((?P[\S\s]*?)\)([,][\s]?|[\s]?$))' + frames = dict() + for match_object in re.finditer(pattern, string): + frame = int(match_object.groupdict()['frame']) + param = match_object.groupdict()['param'] + if prompt_parser: + frames[frame] = prompt_parser(param) + else: + frames[frame] = param + if frames == {} and len(string) != 0: + raise RuntimeError('Key Frame string not correctly formatted') + return frames \ No newline at end of file diff --git a/deforum-stable-diffusion/helpers/callback.py b/deforum-stable-diffusion/helpers/callback.py new file mode 100644 index 0000000000000000000000000000000000000000..647ac6a8cb8c98d8230b72cab595a3fbf381ffe6 --- /dev/null +++ b/deforum-stable-diffusion/helpers/callback.py @@ -0,0 +1,124 @@ +import torch +import os +import torchvision.transforms.functional as TF +from torchvision.utils import make_grid +import numpy as np +from IPython import display + +# +# Callback functions +# +class SamplerCallback(object): + # Creates the callback function to be passed into the samplers for each step + def __init__(self, args, root, mask=None, init_latent=None, sigmas=None, sampler=None, + verbose=False): + self.model = root.model + self.device = root.device + self.sampler_name = args.sampler + self.dynamic_threshold = args.dynamic_threshold + self.static_threshold = args.static_threshold + self.mask = mask + self.init_latent = init_latent + self.sigmas = sigmas + self.sampler = sampler + self.verbose = verbose + self.batch_size = args.n_samples + self.save_sample_per_step = args.save_sample_per_step + self.show_sample_per_step = args.show_sample_per_step + self.paths_to_image_steps = [os.path.join( args.outdir, f"{args.timestring}_{index:02}_{args.seed}") for index in range(args.n_samples) ] + + if self.save_sample_per_step: + for path in self.paths_to_image_steps: + os.makedirs(path, exist_ok=True) + + self.step_index = 0 + + self.noise = None + if init_latent is not None: + self.noise = torch.randn_like(init_latent, device=self.device) + + self.mask_schedule = None + if sigmas is not None and len(sigmas) > 0: + self.mask_schedule, _ = torch.sort(sigmas/torch.max(sigmas)) + elif len(sigmas) == 0: + self.mask = None # no mask needed if no steps (usually happens because strength==1.0) + + if self.sampler_name in ["plms","ddim"]: + if mask is not None: + assert sampler is not None, "Callback function for stable-diffusion samplers requires sampler variable" + + if self.sampler_name in ["plms","ddim"]: + # Callback function formated for compvis latent diffusion samplers + self.callback = self.img_callback_ + else: + # Default callback function uses k-diffusion sampler variables + self.callback = self.k_callback_ + + self.verbose_print = print if verbose else lambda *args, **kwargs: None + + def display_images(self, images): + images = images.double().cpu().add(1).div(2).clamp(0, 1) + images = torch.tensor(np.array(images)) + grid = make_grid(images, 4).cpu() + display.clear_output(wait=True) + display.display(TF.to_pil_image(grid)) + return + + def view_sample_step(self, latents, path_name_modifier=''): + if self.save_sample_per_step: + samples = self.model.decode_first_stage(latents) + fname = f'{path_name_modifier}_{self.step_index:05}.png' + for i, sample in enumerate(samples): + sample = sample.double().cpu().add(1).div(2).clamp(0, 1) + sample = torch.tensor(np.array(sample)) + grid = make_grid(sample, 4).cpu() + TF.to_pil_image(grid).save(os.path.join(self.paths_to_image_steps[i], fname)) + if self.show_sample_per_step: + samples = self.model.linear_decode(latents) + print(path_name_modifier) + self.display_images(samples) + return + + # The callback function is applied to the image at each step + def dynamic_thresholding_(self, img, threshold): + # Dynamic thresholding from Imagen paper (May 2022) + s = np.percentile(np.abs(img.cpu()), threshold, axis=tuple(range(1,img.ndim))) + s = np.max(np.append(s,1.0)) + torch.clamp_(img, -1*s, s) + torch.FloatTensor.div_(img, s) + + # Callback for samplers in the k-diffusion repo, called thus: + # callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + def k_callback_(self, args_dict): + self.step_index = args_dict['i'] + if self.dynamic_threshold is not None: + self.dynamic_thresholding_(args_dict['x'], self.dynamic_threshold) + if self.static_threshold is not None: + torch.clamp_(args_dict['x'], -1*self.static_threshold, self.static_threshold) + if self.mask is not None: + init_noise = self.init_latent + self.noise * args_dict['sigma'] + is_masked = torch.logical_and(self.mask >= self.mask_schedule[args_dict['i']], self.mask != 0 ) + new_img = init_noise * torch.where(is_masked,1,0) + args_dict['x'] * torch.where(is_masked,0,1) + args_dict['x'].copy_(new_img) + + self.view_sample_step(args_dict['denoised'], "x0_pred") + self.view_sample_step(args_dict['x'], "x") + + # Callback for Compvis samplers + # Function that is called on the image (img) and step (i) at each step + def img_callback_(self, img, pred_x0, i): + self.step_index = i + # Thresholding functions + if self.dynamic_threshold is not None: + self.dynamic_thresholding_(img, self.dynamic_threshold) + if self.static_threshold is not None: + torch.clamp_(img, -1*self.static_threshold, self.static_threshold) + if self.mask is not None: + i_inv = len(self.sigmas) - i - 1 + init_noise = self.sampler.stochastic_encode(self.init_latent, torch.tensor([i_inv]*self.batch_size).to(self.device), noise=self.noise) + is_masked = torch.logical_and(self.mask >= self.mask_schedule[i], self.mask != 0 ) + new_img = init_noise * torch.where(is_masked,1,0) + img * torch.where(is_masked,0,1) + img.copy_(new_img) + + self.view_sample_step(pred_x0, "x0_pred") + self.view_sample_step(img, "x") \ No newline at end of file diff --git a/deforum-stable-diffusion/helpers/colors.py b/deforum-stable-diffusion/helpers/colors.py new file mode 100644 index 0000000000000000000000000000000000000000..6ec81e197ef2b918a352d04f57337b956137b0e6 --- /dev/null +++ b/deforum-stable-diffusion/helpers/colors.py @@ -0,0 +1,16 @@ +from skimage.exposure import match_histograms +import cv2 + +def maintain_colors(prev_img, color_match_sample, mode): + if mode == 'Match Frame 0 RGB': + return match_histograms(prev_img, color_match_sample, multichannel=True) + elif mode == 'Match Frame 0 HSV': + prev_img_hsv = cv2.cvtColor(prev_img, cv2.COLOR_RGB2HSV) + color_match_hsv = cv2.cvtColor(color_match_sample, cv2.COLOR_RGB2HSV) + matched_hsv = match_histograms(prev_img_hsv, color_match_hsv, multichannel=True) + return cv2.cvtColor(matched_hsv, cv2.COLOR_HSV2RGB) + else: # Match Frame 0 LAB + prev_img_lab = cv2.cvtColor(prev_img, cv2.COLOR_RGB2LAB) + color_match_lab = cv2.cvtColor(color_match_sample, cv2.COLOR_RGB2LAB) + matched_lab = match_histograms(prev_img_lab, color_match_lab, multichannel=True) + return cv2.cvtColor(matched_lab, cv2.COLOR_LAB2RGB) \ No newline at end of file diff --git a/deforum-stable-diffusion/helpers/conditioning.py b/deforum-stable-diffusion/helpers/conditioning.py new file mode 100644 index 0000000000000000000000000000000000000000..c5ab0db4bc1243db9c47e7efcacde92a9a5479ad --- /dev/null +++ b/deforum-stable-diffusion/helpers/conditioning.py @@ -0,0 +1,262 @@ +import torch +import torch.nn as nn +from torch.nn import functional as F +import clip +from torchvision.transforms import Normalize as Normalize +from torchvision.utils import make_grid +import numpy as np +from IPython import display +from sklearn.cluster import KMeans +import torchvision.transforms.functional as TF + +### +# Loss functions +### + + +## CLIP ----------------------------------------- + +class MakeCutouts(nn.Module): + def __init__(self, cut_size, cutn, cut_pow=1.): + super().__init__() + self.cut_size = cut_size + self.cutn = cutn + self.cut_pow = cut_pow + + def forward(self, input): + sideY, sideX = input.shape[2:4] + max_size = min(sideX, sideY) + min_size = min(sideX, sideY, self.cut_size) + cutouts = [] + for _ in range(self.cutn): + size = int(torch.rand([])**self.cut_pow * (max_size - min_size) + min_size) + offsetx = torch.randint(0, sideX - size + 1, ()) + offsety = torch.randint(0, sideY - size + 1, ()) + cutout = input[:, :, offsety:offsety + size, offsetx:offsetx + size] + cutouts.append(F.adaptive_avg_pool2d(cutout, self.cut_size)) + return torch.cat(cutouts) + + +def spherical_dist_loss(x, y): + x = F.normalize(x, dim=-1) + y = F.normalize(y, dim=-1) + return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2) + +def make_clip_loss_fn(root, args): + clip_size = root.clip_model.visual.input_resolution # for openslip: clip_model.visual.image_size + + def parse_prompt(prompt): + if prompt.startswith('http://') or prompt.startswith('https://'): + vals = prompt.rsplit(':', 2) + vals = [vals[0] + ':' + vals[1], *vals[2:]] + else: + vals = prompt.rsplit(':', 1) + vals = vals + ['', '1'][len(vals):] + return vals[0], float(vals[1]) + + def parse_clip_prompts(clip_prompt): + target_embeds, weights = [], [] + for prompt in clip_prompt: + txt, weight = parse_prompt(prompt) + target_embeds.append(root.clip_model.encode_text(clip.tokenize(txt).to(root.device)).float()) + weights.append(weight) + target_embeds = torch.cat(target_embeds) + weights = torch.tensor(weights, device=root.device) + if weights.sum().abs() < 1e-3: + raise RuntimeError('Clip prompt weights must not sum to 0.') + weights /= weights.sum().abs() + return target_embeds, weights + + normalize = Normalize(mean=[0.48145466, 0.4578275, 0.40821073], + std=[0.26862954, 0.26130258, 0.27577711]) + + make_cutouts = MakeCutouts(clip_size, args.cutn, args.cut_pow) + target_embeds, weights = parse_clip_prompts(args.clip_prompt) + + def clip_loss_fn(x, sigma, **kwargs): + nonlocal target_embeds, weights, make_cutouts, normalize + clip_in = normalize(make_cutouts(x.add(1).div(2))) + image_embeds = root.clip_model.encode_image(clip_in).float() + dists = spherical_dist_loss(image_embeds[:, None], target_embeds[None]) + dists = dists.view([args.cutn, 1, -1]) + losses = dists.mul(weights).sum(2).mean(0) + return losses.sum() + + return clip_loss_fn + +def make_aesthetics_loss_fn(root,args): + clip_size = root.clip_model.visual.input_resolution # for openslip: clip_model.visual.image_size + + def aesthetics_cond_fn(x, sigma, **kwargs): + clip_in = F.interpolate(x, (clip_size, clip_size)) + image_embeds = root.clip_model.encode_image(clip_in).float() + losses = (10 - root.aesthetics_model(image_embeds)[0]) + return losses.sum() + + return aesthetics_cond_fn + +## end CLIP ----------------------------------------- + +# blue loss from @johnowhitaker's tutorial on Grokking Stable Diffusion +def blue_loss_fn(x, sigma, **kwargs): + # How far are the blue channel values to 0.9: + error = torch.abs(x[:,-1, :, :] - 0.9).mean() + return error + +# MSE loss from init +def make_mse_loss(target): + def mse_loss(x, sigma, **kwargs): + return (x - target).square().mean() + return mse_loss + +# MSE loss from init +def exposure_loss(target): + def exposure_loss_fn(x, sigma, **kwargs): + error = torch.abs(x-target).mean() + return error + return exposure_loss_fn + +def mean_loss_fn(x, sigma, **kwargs): + error = torch.abs(x).mean() + return error + +def var_loss_fn(x, sigma, **kwargs): + error = x.var() + return error + +def get_color_palette(root, n_colors, target, verbose=False): + def display_color_palette(color_list): + # Expand to 64x64 grid of single color pixels + images = color_list.unsqueeze(2).repeat(1,1,64).unsqueeze(3).repeat(1,1,1,64) + images = images.double().cpu().add(1).div(2).clamp(0, 1) + images = torch.tensor(np.array(images)) + grid = make_grid(images, 8).cpu() + display.display(TF.to_pil_image(grid)) + return + + # Create color palette + kmeans = KMeans(n_clusters=n_colors, random_state=0).fit(torch.flatten(target[0],1,2).T.cpu().numpy()) + color_list = torch.Tensor(kmeans.cluster_centers_).to(root.device) + if verbose: + display_color_palette(color_list) + # Get ratio of each color class in the target image + color_indexes, color_counts = np.unique(kmeans.labels_, return_counts=True) + # color_list = color_list[color_indexes] + return color_list, color_counts + +def make_rgb_color_match_loss(root, target, n_colors, ignore_sat_weight=None, img_shape=None, device='cuda:0'): + """ + target (tensor): Image sample (values from -1 to 1) to extract the color palette + n_colors (int): Number of colors in the color palette + ignore_sat_weight (None or number>0): Scale to ignore color saturation in color comparison + img_shape (None or (int, int)): shape (width, height) of sample that the conditioning gradient is applied to, + if None then calculate target color distribution during gradient calculation + rather than once at the beginning + """ + assert n_colors > 0, "Must use at least one color with color match loss" + + def adjust_saturation(sample, saturation_factor): + # as in torchvision.transforms.functional.adjust_saturation, but for tensors with values from -1,1 + return blend(sample, TF.rgb_to_grayscale(sample), saturation_factor) + + def blend(img1, img2, ratio): + return (ratio * img1 + (1.0 - ratio) * img2).clamp(-1, 1).to(img1.dtype) + + def color_distance_distributions(n_colors, img_shape, color_list, color_counts, n_images=1): + # Get the target color distance distributions + # Ensure color counts total the amout of pixels in the image + n_pixels = img_shape[0]*img_shape[1] + color_counts = (color_counts * n_pixels / sum(color_counts)).astype(int) + + # Make color distances for each color, sorted by distance + color_distributions = torch.zeros((n_colors, n_images, n_pixels), device=device) + for i_image in range(n_images): + for ic,color0 in enumerate(color_list): + i_dist = 0 + for jc,color1 in enumerate(color_list): + color_dist = torch.linalg.norm(color0 - color1) + color_distributions[ic, i_image, i_dist:i_dist+color_counts[jc]] = color_dist + i_dist += color_counts[jc] + color_distributions, _ = torch.sort(color_distributions,dim=2) + return color_distributions + + color_list, color_counts = get_color_palette(root, n_colors, target) + color_distributions = None + if img_shape is not None: + color_distributions = color_distance_distributions(n_colors, img_shape, color_list, color_counts) + + def rgb_color_ratio_loss(x, sigma, **kwargs): + nonlocal color_distributions + all_color_norm_distances = torch.ones(len(color_list), x.shape[0], x.shape[2], x.shape[3]).to(device) * 6.0 # distance to color won't be more than max norm1 distance between -1 and 1 in 3 color dimensions + + for ic,color in enumerate(color_list): + # Make a tensor of entirely one color + color = color[None,:,None].repeat(1,1,x.shape[2]).unsqueeze(3).repeat(1,1,1,x.shape[3]) + # Get the color distances + if ignore_sat_weight is None: + # Simple color distance + color_distances = torch.linalg.norm(x - color, dim=1) + else: + # Color distance if the colors were saturated + # This is to make color comparison ignore shadows and highlights, for example + color_distances = torch.linalg.norm(adjust_saturation(x, ignore_sat_weight) - color, dim=1) + + all_color_norm_distances[ic] = color_distances + all_color_norm_distances = torch.flatten(all_color_norm_distances,start_dim=2) + + if color_distributions is None: + color_distributions = color_distance_distributions(n_colors, + (x.shape[2], x.shape[3]), + color_list, + color_counts, + n_images=x.shape[0]) + + # Sort the color distances so we can compare them as if they were a cumulative distribution function + all_color_norm_distances, _ = torch.sort(all_color_norm_distances,dim=2) + + color_norm_distribution_diff = all_color_norm_distances - color_distributions + + return color_norm_distribution_diff.square().mean() + + return rgb_color_ratio_loss + + +### +# Thresholding functions for grad +### +def threshold_by(threshold, threshold_type, clamp_schedule): + + def dynamic_thresholding(vals, sigma): + # Dynamic thresholding from Imagen paper (May 2022) + s = np.percentile(np.abs(vals.cpu()), threshold, axis=tuple(range(1,vals.ndim))) + s = np.max(np.append(s,1.0)) + vals = torch.clamp(vals, -1*s, s) + vals = torch.FloatTensor.div(vals, s) + return vals + + def static_thresholding(vals, sigma): + vals = torch.clamp(vals, -1*threshold, threshold) + return vals + + def mean_thresholding(vals, sigma): # Thresholding that appears in Jax and Disco + magnitude = vals.square().mean(axis=(1,2,3),keepdims=True).sqrt() + vals = vals * torch.where(magnitude > threshold, threshold / magnitude, 1.0) + return vals + + def scheduling(vals, sigma): + clamp_val = clamp_schedule[sigma.item()] + magnitude = vals.square().mean().sqrt() + vals = vals * magnitude.clamp(max=clamp_val) / magnitude + #print(clamp_val) + return vals + + if threshold_type == 'dynamic': + return dynamic_thresholding + elif threshold_type == 'static': + return static_thresholding + elif threshold_type == 'mean': + return mean_thresholding + elif threshold_type == 'schedule': + return scheduling + else: + raise Exception(f"Thresholding type {threshold_type} not supported") diff --git a/deforum-stable-diffusion/helpers/depth.py b/deforum-stable-diffusion/helpers/depth.py new file mode 100644 index 0000000000000000000000000000000000000000..76d7020758eaf6a288403042191922f0f0d7beca --- /dev/null +++ b/deforum-stable-diffusion/helpers/depth.py @@ -0,0 +1,175 @@ +import cv2 +import math +import numpy as np +import os +import requests +import torch +import torchvision.transforms as T +import torchvision.transforms.functional as TF + +from einops import rearrange, repeat +from PIL import Image + +from infer import InferenceHelper +from midas.dpt_depth import DPTDepthModel +from midas.transforms import Resize, NormalizeImage, PrepareForNet + + +def wget(url, outputdir): + filename = url.split("/")[-1] + + ckpt_request = requests.get(url) + request_status = ckpt_request.status_code + + # inform user of errors + if request_status == 403: + raise ConnectionRefusedError("You have not accepted the license for this model.") + elif request_status == 404: + raise ConnectionError("Could not make contact with server") + elif request_status != 200: + raise ConnectionError(f"Some other error has ocurred - response code: {request_status}") + + # write to model path + with open(os.path.join(outputdir, filename), 'wb') as model_file: + model_file.write(ckpt_request.content) + + +class DepthModel(): + def __init__(self, device): + self.adabins_helper = None + self.depth_min = 1000 + self.depth_max = -1000 + self.device = device + self.midas_model = None + self.midas_transform = None + + def load_adabins(self, models_path): + if not os.path.exists(os.path.join(models_path,'AdaBins_nyu.pt')): + print("Downloading AdaBins_nyu.pt...") + os.makedirs(models_path, exist_ok=True) + wget("https://cloudflare-ipfs.com/ipfs/Qmd2mMnDLWePKmgfS8m6ntAg4nhV5VkUyAydYBp8cWWeB7/AdaBins_nyu.pt", models_path) + self.adabins_helper = InferenceHelper(models_path, dataset='nyu', device=self.device) + + def load_midas(self, models_path, half_precision=True): + if not os.path.exists(os.path.join(models_path, 'dpt_large-midas-2f21e586.pt')): + print("Downloading dpt_large-midas-2f21e586.pt...") + wget("https://github.com/intel-isl/DPT/releases/download/1_0/dpt_large-midas-2f21e586.pt", models_path) + + self.midas_model = DPTDepthModel( + path=os.path.join(models_path, "dpt_large-midas-2f21e586.pt"), + backbone="vitl16_384", + non_negative=True, + ) + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + self.midas_transform = T.Compose([ + Resize( + 384, 384, + resize_target=None, + keep_aspect_ratio=True, + ensure_multiple_of=32, + resize_method="minimal", + image_interpolation_method=cv2.INTER_CUBIC, + ), + normalization, + PrepareForNet() + ]) + + self.midas_model.eval() + if half_precision and self.device == torch.device("cuda"): + self.midas_model = self.midas_model.to(memory_format=torch.channels_last) + self.midas_model = self.midas_model.half() + self.midas_model.to(self.device) + + def predict(self, prev_img_cv2, anim_args) -> torch.Tensor: + w, h = prev_img_cv2.shape[1], prev_img_cv2.shape[0] + + # predict depth with AdaBins + use_adabins = anim_args.midas_weight < 1.0 and self.adabins_helper is not None + if use_adabins: + MAX_ADABINS_AREA = 500000 + MIN_ADABINS_AREA = 448*448 + + # resize image if too large or too small + img_pil = Image.fromarray(cv2.cvtColor(prev_img_cv2.astype(np.uint8), cv2.COLOR_RGB2BGR)) + image_pil_area = w*h + resized = True + if image_pil_area > MAX_ADABINS_AREA: + scale = math.sqrt(MAX_ADABINS_AREA) / math.sqrt(image_pil_area) + depth_input = img_pil.resize((int(w*scale), int(h*scale)), Image.LANCZOS) # LANCZOS is good for downsampling + print(f" resized to {depth_input.width}x{depth_input.height}") + elif image_pil_area < MIN_ADABINS_AREA: + scale = math.sqrt(MIN_ADABINS_AREA) / math.sqrt(image_pil_area) + depth_input = img_pil.resize((int(w*scale), int(h*scale)), Image.BICUBIC) + print(f" resized to {depth_input.width}x{depth_input.height}") + else: + depth_input = img_pil + resized = False + + # predict depth and resize back to original dimensions + try: + with torch.no_grad(): + _, adabins_depth = self.adabins_helper.predict_pil(depth_input) + if resized: + adabins_depth = TF.resize( + torch.from_numpy(adabins_depth), + torch.Size([h, w]), + interpolation=TF.InterpolationMode.BICUBIC + ) + adabins_depth = adabins_depth.cpu().numpy() + adabins_depth = adabins_depth.squeeze() + except: + print(f" exception encountered, falling back to pure MiDaS") + use_adabins = False + torch.cuda.empty_cache() + + if self.midas_model is not None: + # convert image from 0->255 uint8 to 0->1 float for feeding to MiDaS + img_midas = prev_img_cv2.astype(np.float32) / 255.0 + img_midas_input = self.midas_transform({"image": img_midas})["image"] + + # MiDaS depth estimation implementation + sample = torch.from_numpy(img_midas_input).float().to(self.device).unsqueeze(0) + if self.device == torch.device("cuda"): + sample = sample.to(memory_format=torch.channels_last) + sample = sample.half() + with torch.no_grad(): + midas_depth = self.midas_model.forward(sample) + midas_depth = torch.nn.functional.interpolate( + midas_depth.unsqueeze(1), + size=img_midas.shape[:2], + mode="bicubic", + align_corners=False, + ).squeeze() + midas_depth = midas_depth.cpu().numpy() + torch.cuda.empty_cache() + + # MiDaS makes the near values greater, and the far values lesser. Let's reverse that and try to align with AdaBins a bit better. + midas_depth = np.subtract(50.0, midas_depth) + midas_depth = midas_depth / 19.0 + + # blend between MiDaS and AdaBins predictions + if use_adabins: + depth_map = midas_depth*anim_args.midas_weight + adabins_depth*(1.0-anim_args.midas_weight) + else: + depth_map = midas_depth + + depth_map = np.expand_dims(depth_map, axis=0) + depth_tensor = torch.from_numpy(depth_map).squeeze().to(self.device) + else: + depth_tensor = torch.ones((h, w), device=self.device) + + return depth_tensor + + def save(self, filename: str, depth: torch.Tensor): + depth = depth.cpu().numpy() + if len(depth.shape) == 2: + depth = np.expand_dims(depth, axis=0) + self.depth_min = min(self.depth_min, depth.min()) + self.depth_max = max(self.depth_max, depth.max()) + print(f" depth min:{depth.min()} max:{depth.max()}") + denom = max(1e-8, self.depth_max - self.depth_min) + temp = rearrange((depth - self.depth_min) / denom * 255, 'c h w -> h w c') + temp = repeat(temp, 'h w 1 -> h w c', c=3) + Image.fromarray(temp.astype(np.uint8)).save(filename) + diff --git a/deforum-stable-diffusion/helpers/generate.py b/deforum-stable-diffusion/helpers/generate.py new file mode 100644 index 0000000000000000000000000000000000000000..fc58cf1ca9fd641a508648a0c19fcb4a9dcd8bb7 --- /dev/null +++ b/deforum-stable-diffusion/helpers/generate.py @@ -0,0 +1,282 @@ +import torch +from PIL import Image +import requests +import numpy as np +import torchvision.transforms.functional as TF +from pytorch_lightning import seed_everything +import os +from ldm.models.diffusion.plms import PLMSSampler +from ldm.models.diffusion.ddim import DDIMSampler +from k_diffusion.external import CompVisDenoiser +from torch import autocast +from contextlib import nullcontext +from einops import rearrange, repeat + +from .prompt import get_uc_and_c +from .k_samplers import sampler_fn, make_inject_timing_fn +from scipy.ndimage import gaussian_filter + +from .callback import SamplerCallback + +from .conditioning import exposure_loss, make_mse_loss, get_color_palette, make_clip_loss_fn +from .conditioning import make_rgb_color_match_loss, blue_loss_fn, threshold_by, make_aesthetics_loss_fn, mean_loss_fn, var_loss_fn, exposure_loss +from .model_wrap import CFGDenoiserWithGrad +from .load_images import load_img, load_mask_latent, prepare_mask, prepare_overlay_mask + +def add_noise(sample: torch.Tensor, noise_amt: float) -> torch.Tensor: + return sample + torch.randn(sample.shape, device=sample.device) * noise_amt + +def generate(args, root, frame = 0, return_latent=False, return_sample=False, return_c=False): + seed_everything(args.seed) + os.makedirs(args.outdir, exist_ok=True) + + sampler = PLMSSampler(root.model) if args.sampler == 'plms' else DDIMSampler(root.model) + model_wrap = CompVisDenoiser(root.model) + batch_size = args.n_samples + prompt = args.prompt + assert prompt is not None + data = [batch_size * [prompt]] + precision_scope = autocast if args.precision == "autocast" else nullcontext + + init_latent = None + mask_image = None + init_image = None + if args.init_latent is not None: + init_latent = args.init_latent + elif args.init_sample is not None: + with precision_scope("cuda"): + init_latent = root.model.get_first_stage_encoding(root.model.encode_first_stage(args.init_sample)) + elif args.use_init and args.init_image != None and args.init_image != '': + init_image, mask_image = load_img(args.init_image, + shape=(args.W, args.H), + use_alpha_as_mask=args.use_alpha_as_mask) + init_image = init_image.to(root.device) + init_image = repeat(init_image, '1 ... -> b ...', b=batch_size) + with precision_scope("cuda"): + init_latent = root.model.get_first_stage_encoding(root.model.encode_first_stage(init_image)) # move to latent space + + if not args.use_init and args.strength > 0 and args.strength_0_no_init: + print("\nNo init image, but strength > 0. Strength has been auto set to 0, since use_init is False.") + print("If you want to force strength > 0 with no init, please set strength_0_no_init to False.\n") + args.strength = 0 + + # Mask functions + if args.use_mask: + assert args.mask_file is not None or mask_image is not None, "use_mask==True: An mask image is required for a mask. Please enter a mask_file or use an init image with an alpha channel" + assert args.use_init, "use_mask==True: use_init is required for a mask" + assert init_latent is not None, "use_mask==True: An latent init image is required for a mask" + + + mask = prepare_mask(args.mask_file if mask_image is None else mask_image, + init_latent.shape, + args.mask_contrast_adjust, + args.mask_brightness_adjust, + args.invert_mask) + + if (torch.all(mask == 0) or torch.all(mask == 1)) and args.use_alpha_as_mask: + raise Warning("use_alpha_as_mask==True: Using the alpha channel from the init image as a mask, but the alpha channel is blank.") + + mask = mask.to(root.device) + mask = repeat(mask, '1 ... -> b ...', b=batch_size) + else: + mask = None + + assert not ( (args.use_mask and args.overlay_mask) and (args.init_sample is None and init_image is None)), "Need an init image when use_mask == True and overlay_mask == True" + + # Init MSE loss image + init_mse_image = None + if args.init_mse_scale and args.init_mse_image != None and args.init_mse_image != '': + init_mse_image, mask_image = load_img(args.init_mse_image, + shape=(args.W, args.H), + use_alpha_as_mask=args.use_alpha_as_mask) + init_mse_image = init_mse_image.to(root.device) + init_mse_image = repeat(init_mse_image, '1 ... -> b ...', b=batch_size) + + assert not ( args.init_mse_scale != 0 and (args.init_mse_image is None or args.init_mse_image == '') ), "Need an init image when init_mse_scale != 0" + + t_enc = int((1.0-args.strength) * args.steps) + + # Noise schedule for the k-diffusion samplers (used for masking) + k_sigmas = model_wrap.get_sigmas(args.steps) + args.clamp_schedule = dict(zip(k_sigmas.tolist(), np.linspace(args.clamp_start,args.clamp_stop,args.steps+1))) + k_sigmas = k_sigmas[len(k_sigmas)-t_enc-1:] + + if args.sampler in ['plms','ddim']: + sampler.make_schedule(ddim_num_steps=args.steps, ddim_eta=args.ddim_eta, ddim_discretize='fill', verbose=False) + + if args.colormatch_scale != 0: + assert args.colormatch_image is not None, "If using color match loss, colormatch_image is needed" + colormatch_image, _ = load_img(args.colormatch_image) + colormatch_image = colormatch_image.to('cpu') + del(_) + else: + colormatch_image = None + + # Loss functions + if args.init_mse_scale != 0: + if args.decode_method == "linear": + mse_loss_fn = make_mse_loss(root.model.linear_decode(root.model.get_first_stage_encoding(root.model.encode_first_stage(init_mse_image.to(root.device))))) + else: + mse_loss_fn = make_mse_loss(init_mse_image) + else: + mse_loss_fn = None + + if args.colormatch_scale != 0: + _,_ = get_color_palette(root, args.colormatch_n_colors, colormatch_image, verbose=True) # display target color palette outside the latent space + if args.decode_method == "linear": + grad_img_shape = (int(args.W/args.f), int(args.H/args.f)) + colormatch_image = root.model.linear_decode(root.model.get_first_stage_encoding(root.model.encode_first_stage(colormatch_image.to(root.device)))) + colormatch_image = colormatch_image.to('cpu') + else: + grad_img_shape = (args.W, args.H) + color_loss_fn = make_rgb_color_match_loss(root, + colormatch_image, + n_colors=args.colormatch_n_colors, + img_shape=grad_img_shape, + ignore_sat_weight=args.ignore_sat_weight) + else: + color_loss_fn = None + + if args.clip_scale != 0: + clip_loss_fn = make_clip_loss_fn(root, args) + else: + clip_loss_fn = None + + if args.aesthetics_scale != 0: + aesthetics_loss_fn = make_aesthetics_loss_fn(root, args) + else: + aesthetics_loss_fn = None + + if args.exposure_scale != 0: + exposure_loss_fn = exposure_loss(args.exposure_target) + else: + exposure_loss_fn = None + + loss_fns_scales = [ + [clip_loss_fn, args.clip_scale], + [blue_loss_fn, args.blue_scale], + [mean_loss_fn, args.mean_scale], + [exposure_loss_fn, args.exposure_scale], + [var_loss_fn, args.var_scale], + [mse_loss_fn, args.init_mse_scale], + [color_loss_fn, args.colormatch_scale], + [aesthetics_loss_fn, args.aesthetics_scale] + ] + + # Conditioning gradients not implemented for ddim or PLMS + assert not( any([cond_fs[1]!=0 for cond_fs in loss_fns_scales]) and (args.sampler in ["ddim","plms"]) ), "Conditioning gradients not implemented for ddim or plms. Please use a different sampler." + + callback = SamplerCallback(args=args, + root=root, + mask=mask, + init_latent=init_latent, + sigmas=k_sigmas, + sampler=sampler, + verbose=False).callback + + clamp_fn = threshold_by(threshold=args.clamp_grad_threshold, threshold_type=args.grad_threshold_type, clamp_schedule=args.clamp_schedule) + + grad_inject_timing_fn = make_inject_timing_fn(args.grad_inject_timing, model_wrap, args.steps) + + cfg_model = CFGDenoiserWithGrad(model_wrap, + loss_fns_scales, + clamp_fn, + args.gradient_wrt, + args.gradient_add_to, + args.cond_uncond_sync, + decode_method=args.decode_method, + grad_inject_timing_fn=grad_inject_timing_fn, # option to use grad in only a few of the steps + grad_consolidate_fn=None, # function to add grad to image fn(img, grad, sigma) + verbose=False) + + results = [] + with torch.no_grad(): + with precision_scope("cuda"): + with root.model.ema_scope(): + for prompts in data: + if isinstance(prompts, tuple): + prompts = list(prompts) + if args.prompt_weighting: + uc, c = get_uc_and_c(prompts, root.model, args, frame) + else: + uc = root.model.get_learned_conditioning(batch_size * [""]) + c = root.model.get_learned_conditioning(prompts) + + + if args.scale == 1.0: + uc = None + if args.init_c != None: + c = args.init_c + + if args.sampler in ["klms","dpm2","dpm2_ancestral","heun","euler","euler_ancestral", "dpm_fast", "dpm_adaptive", "dpmpp_2s_a", "dpmpp_2m"]: + samples = sampler_fn( + c=c, + uc=uc, + args=args, + model_wrap=cfg_model, + init_latent=init_latent, + t_enc=t_enc, + device=root.device, + cb=callback, + verbose=False) + else: + # args.sampler == 'plms' or args.sampler == 'ddim': + if init_latent is not None and args.strength > 0: + z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc]*batch_size).to(root.device)) + else: + z_enc = torch.randn([args.n_samples, args.C, args.H // args.f, args.W // args.f], device=root.device) + if args.sampler == 'ddim': + samples = sampler.decode(z_enc, + c, + t_enc, + unconditional_guidance_scale=args.scale, + unconditional_conditioning=uc, + img_callback=callback) + elif args.sampler == 'plms': # no "decode" function in plms, so use "sample" + shape = [args.C, args.H // args.f, args.W // args.f] + samples, _ = sampler.sample(S=args.steps, + conditioning=c, + batch_size=args.n_samples, + shape=shape, + verbose=False, + unconditional_guidance_scale=args.scale, + unconditional_conditioning=uc, + eta=args.ddim_eta, + x_T=z_enc, + img_callback=callback) + else: + raise Exception(f"Sampler {args.sampler} not recognised.") + + + if return_latent: + results.append(samples.clone()) + + x_samples = root.model.decode_first_stage(samples) + + if args.use_mask and args.overlay_mask: + # Overlay the masked image after the image is generated + if args.init_sample_raw is not None: + img_original = args.init_sample_raw + elif init_image is not None: + img_original = init_image + else: + raise Exception("Cannot overlay the masked image without an init image to overlay") + + if args.mask_sample is None: + args.mask_sample = prepare_overlay_mask(args, root, img_original.shape) + + x_samples = img_original * args.mask_sample + x_samples * ((args.mask_sample * -1.0) + 1) + + if return_sample: + results.append(x_samples.clone()) + + x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0) + + if return_c: + results.append(c.clone()) + + for x_sample in x_samples: + x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c') + image = Image.fromarray(x_sample.astype(np.uint8)) + results.append(image) + return results diff --git a/deforum-stable-diffusion/helpers/k_samplers.py b/deforum-stable-diffusion/helpers/k_samplers.py new file mode 100644 index 0000000000000000000000000000000000000000..e851810d740f7f511ad3aecc4b6159eefd2c5c35 --- /dev/null +++ b/deforum-stable-diffusion/helpers/k_samplers.py @@ -0,0 +1,124 @@ +from typing import Any, Callable, Optional +from k_diffusion.external import CompVisDenoiser +from k_diffusion import sampling +import torch + + +def sampler_fn( + c: torch.Tensor, + uc: torch.Tensor, + args, + model_wrap: CompVisDenoiser, + init_latent: Optional[torch.Tensor] = None, + t_enc: Optional[torch.Tensor] = None, + device=torch.device("cpu") + if not torch.cuda.is_available() + else torch.device("cuda"), + cb: Callable[[Any], None] = None, + verbose: Optional[bool] = False, +) -> torch.Tensor: + shape = [args.C, args.H // args.f, args.W // args.f] + sigmas: torch.Tensor = model_wrap.get_sigmas(args.steps) + sigmas = sigmas[len(sigmas) - t_enc - 1 :] + if args.use_init: + if len(sigmas) > 0: + x = ( + init_latent + + torch.randn([args.n_samples, *shape], device=device) * sigmas[0] + ) + else: + x = init_latent + else: + if len(sigmas) > 0: + x = torch.randn([args.n_samples, *shape], device=device) * sigmas[0] + else: + x = torch.zeros([args.n_samples, *shape], device=device) + sampler_args = { + "model": model_wrap, + "x": x, + "sigmas": sigmas, + "extra_args": {"cond": c, "uncond": uc, "cond_scale": args.scale}, + "disable": False, + "callback": cb, + } + min = sigmas[0].item() + max = min + for i in sigmas: + if i.item() < min and i.item() != 0.0: + min = i.item() + if args.sampler in ["dpm_fast"]: + sampler_args = { + "model": model_wrap, + "x": x, + "sigma_min": min, + "sigma_max": max, + "extra_args": {"cond": c, "uncond": uc, "cond_scale": args.scale}, + "disable": False, + "callback": cb, + "n":args.steps, + } + elif args.sampler in ["dpm_adaptive"]: + sampler_args = { + "model": model_wrap, + "x": x, + "sigma_min": min, + "sigma_max": max, + "extra_args": {"cond": c, "uncond": uc, "cond_scale": args.scale}, + "disable": False, + "callback": cb, + } + sampler_map = { + "klms": sampling.sample_lms, + "dpm2": sampling.sample_dpm_2, + "dpm2_ancestral": sampling.sample_dpm_2_ancestral, + "heun": sampling.sample_heun, + "euler": sampling.sample_euler, + "euler_ancestral": sampling.sample_euler_ancestral, + "dpm_fast": sampling.sample_dpm_fast, + "dpm_adaptive": sampling.sample_dpm_adaptive, + "dpmpp_2s_a": sampling.sample_dpmpp_2s_ancestral, + "dpmpp_2m": sampling.sample_dpmpp_2m, + } + + samples = sampler_map[args.sampler](**sampler_args) + return samples + + +def make_inject_timing_fn(inject_timing, model, steps): + """ + inject_timing (int or list of ints or list of floats between 0.0 and 1.0): + int: compute every inject_timing steps + list of floats: compute on these decimal fraction steps (eg, [0.5, 1.0] for 50 steps would be at steps 25 and 50) + list of ints: compute on these steps + model (CompVisDenoiser) + steps (int): number of steps + """ + all_sigmas = model.get_sigmas(steps) + target_sigmas = torch.empty([0], device=all_sigmas.device) + + def timing_fn(sigma): + is_conditioning_step = False + if sigma in target_sigmas: + is_conditioning_step = True + return is_conditioning_step + + if inject_timing is None: + timing_fn = lambda sigma: True + elif isinstance(inject_timing,int) and inject_timing <= steps and inject_timing > 0: + # Compute every nth step + target_sigma_list = [sigma for i,sigma in enumerate(all_sigmas) if (i+1) % inject_timing == 0] + target_sigmas = torch.Tensor(target_sigma_list).to(all_sigmas.device) + elif all(isinstance(t,float) for t in inject_timing) and all(t>=0.0 and t<=1.0 for t in inject_timing): + # Compute on these steps (expressed as a decimal fraction between 0.0 and 1.0) + target_indices = [int(frac_step*steps) if frac_step < 1.0 else steps-1 for frac_step in inject_timing] + target_sigma_list = [sigma for i,sigma in enumerate(all_sigmas) if i in target_indices] + target_sigmas = torch.Tensor(target_sigma_list).to(all_sigmas.device) + elif all(isinstance(t,int) for t in inject_timing) and all(t>0 and t<=steps for t in inject_timing): + # Compute on these steps + target_sigma_list = [sigma for i,sigma in enumerate(all_sigmas) if i+1 in inject_timing] + target_sigmas = torch.Tensor(target_sigma_list).to(all_sigmas.device) + + else: + raise Exception(f"Not a valid input: inject_timing={inject_timing}\n" + + f"Must be an int, list of all ints (between step 1 and {steps}), or list of all floats between 0.0 and 1.0") + return timing_fn \ No newline at end of file diff --git a/deforum-stable-diffusion/helpers/load_images.py b/deforum-stable-diffusion/helpers/load_images.py new file mode 100644 index 0000000000000000000000000000000000000000..ed3cb4f2397362cdd14faa296710b803f72896cb --- /dev/null +++ b/deforum-stable-diffusion/helpers/load_images.py @@ -0,0 +1,99 @@ +import torch +import requests +from PIL import Image +import numpy as np +import torchvision.transforms.functional as TF +from einops import repeat +from scipy.ndimage import gaussian_filter + +def load_img(path, shape=None, use_alpha_as_mask=False): + # use_alpha_as_mask: Read the alpha channel of the image as the mask image + if path.startswith('http://') or path.startswith('https://'): + image = Image.open(requests.get(path, stream=True).raw) + else: + image = Image.open(path) + + if use_alpha_as_mask: + image = image.convert('RGBA') + else: + image = image.convert('RGB') + + if shape is not None: + image = image.resize(shape, resample=Image.LANCZOS) + + mask_image = None + if use_alpha_as_mask: + # Split alpha channel into a mask_image + red, green, blue, alpha = Image.Image.split(image) + mask_image = alpha.convert('L') + image = image.convert('RGB') + + image = np.array(image).astype(np.float16) / 255.0 + image = image[None].transpose(0, 3, 1, 2) + image = torch.from_numpy(image) + image = 2.*image - 1. + + return image, mask_image + +def load_mask_latent(mask_input, shape): + # mask_input (str or PIL Image.Image): Path to the mask image or a PIL Image object + # shape (list-like len(4)): shape of the image to match, usually latent_image.shape + + if isinstance(mask_input, str): # mask input is probably a file name + if mask_input.startswith('http://') or mask_input.startswith('https://'): + mask_image = Image.open(requests.get(mask_input, stream=True).raw).convert('RGBA') + else: + mask_image = Image.open(mask_input).convert('RGBA') + elif isinstance(mask_input, Image.Image): + mask_image = mask_input + else: + raise Exception("mask_input must be a PIL image or a file name") + + mask_w_h = (shape[-1], shape[-2]) + mask = mask_image.resize(mask_w_h, resample=Image.LANCZOS) + mask = mask.convert("L") + return mask + +def prepare_mask(mask_input, mask_shape, mask_brightness_adjust=1.0, mask_contrast_adjust=1.0, invert_mask=False): + # mask_input (str or PIL Image.Image): Path to the mask image or a PIL Image object + # shape (list-like len(4)): shape of the image to match, usually latent_image.shape + # mask_brightness_adjust (non-negative float): amount to adjust brightness of the iamge, + # 0 is black, 1 is no adjustment, >1 is brighter + # mask_contrast_adjust (non-negative float): amount to adjust contrast of the image, + # 0 is a flat grey image, 1 is no adjustment, >1 is more contrast + + mask = load_mask_latent(mask_input, mask_shape) + + # Mask brightness/contrast adjustments + if mask_brightness_adjust != 1: + mask = TF.adjust_brightness(mask, mask_brightness_adjust) + if mask_contrast_adjust != 1: + mask = TF.adjust_contrast(mask, mask_contrast_adjust) + + # Mask image to array + mask = np.array(mask).astype(np.float32) / 255.0 + mask = np.tile(mask,(4,1,1)) + mask = np.expand_dims(mask,axis=0) + mask = torch.from_numpy(mask) + + if invert_mask: + mask = ( (mask - 0.5) * -1) + 0.5 + + mask = np.clip(mask,0,1) + return mask + +def prepare_overlay_mask(args, root, mask_shape): + mask_fullres = prepare_mask(args.mask_file, + mask_shape, + args.mask_contrast_adjust, + args.mask_brightness_adjust, + args.invert_mask) + mask_fullres = mask_fullres[:,:3,:,:] + mask_fullres = repeat(mask_fullres, '1 ... -> b ...', b=args.n_samples) + + mask_fullres[mask_fullres < mask_fullres.max()] = 0 + mask_fullres = gaussian_filter(mask_fullres, args.mask_overlay_blur) + mask_fullres = torch.Tensor(mask_fullres).to(root.device) + return mask_fullres + + diff --git a/deforum-stable-diffusion/helpers/model_load.py b/deforum-stable-diffusion/helpers/model_load.py new file mode 100644 index 0000000000000000000000000000000000000000..333524911801b32e44e42a2eb1de00d9c977b1a2 --- /dev/null +++ b/deforum-stable-diffusion/helpers/model_load.py @@ -0,0 +1,257 @@ +import os +import torch + +# Decodes the image without passing through the upscaler. The resulting image will be the same size as the latent +# Thanks to Kevin Turner (https://github.com/keturn) we have a shortcut to look at the decoded image! +def make_linear_decode(model_version, device='cuda:0'): + v1_4_rgb_latent_factors = [ + # R G B + [ 0.298, 0.207, 0.208], # L1 + [ 0.187, 0.286, 0.173], # L2 + [-0.158, 0.189, 0.264], # L3 + [-0.184, -0.271, -0.473], # L4 + ] + + if model_version[:5] == "sd-v1": + rgb_latent_factors = torch.Tensor(v1_4_rgb_latent_factors).to(device) + else: + raise Exception(f"Model name {model_version} not recognized.") + + def linear_decode(latent): + latent_image = latent.permute(0, 2, 3, 1) @ rgb_latent_factors + latent_image = latent_image.permute(0, 3, 1, 2) + return latent_image + + return linear_decode + +def load_model(root, load_on_run_all=True, check_sha256=True): + + import requests + import torch + from ldm.util import instantiate_from_config + from omegaconf import OmegaConf + from transformers import logging + logging.set_verbosity_error() + + try: + ipy = get_ipython() + except: + ipy = 'could not get_ipython' + + if 'google.colab' in str(ipy): + path_extend = "deforum-stable-diffusion" + else: + path_extend = "" + + model_map = { + "512-base-ema.ckpt": { + 'sha256': 'd635794c1fedfdfa261e065370bea59c651fc9bfa65dc6d67ad29e11869a1824', + 'url': 'https://huggingface.co/stabilityai/stable-diffusion-2-base/resolve/main/512-base-ema.ckpt', + 'requires_login': True, + }, + "v1-5-pruned.ckpt": { + 'sha256': 'e1441589a6f3c5a53f5f54d0975a18a7feb7cdf0b0dee276dfc3331ae376a053', + 'url': 'https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned.ckpt', + 'requires_login': True, + }, + "v1-5-pruned-emaonly.ckpt": { + 'sha256': 'cc6cb27103417325ff94f52b7a5d2dde45a7515b25c255d8e396c90014281516', + 'url': 'https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.ckpt', + 'requires_login': True, + }, + "sd-v1-4-full-ema.ckpt": { + 'sha256': '14749efc0ae8ef0329391ad4436feb781b402f4fece4883c7ad8d10556d8a36a', + 'url': 'https://huggingface.co/CompVis/stable-diffusion-v-1-2-original/blob/main/sd-v1-4-full-ema.ckpt', + 'requires_login': True, + }, + "sd-v1-4.ckpt": { + 'sha256': 'fe4efff1e174c627256e44ec2991ba279b3816e364b49f9be2abc0b3ff3f8556', + 'url': 'https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt', + 'requires_login': True, + }, + "sd-v1-3-full-ema.ckpt": { + 'sha256': '54632c6e8a36eecae65e36cb0595fab314e1a1545a65209f24fde221a8d4b2ca', + 'url': 'https://huggingface.co/CompVis/stable-diffusion-v-1-3-original/blob/main/sd-v1-3-full-ema.ckpt', + 'requires_login': True, + }, + "sd-v1-3.ckpt": { + 'sha256': '2cff93af4dcc07c3e03110205988ff98481e86539c51a8098d4f2236e41f7f2f', + 'url': 'https://huggingface.co/CompVis/stable-diffusion-v-1-3-original/resolve/main/sd-v1-3.ckpt', + 'requires_login': True, + }, + "sd-v1-2-full-ema.ckpt": { + 'sha256': 'bc5086a904d7b9d13d2a7bccf38f089824755be7261c7399d92e555e1e9ac69a', + 'url': 'https://huggingface.co/CompVis/stable-diffusion-v-1-2-original/blob/main/sd-v1-2-full-ema.ckpt', + 'requires_login': True, + }, + "sd-v1-2.ckpt": { + 'sha256': '3b87d30facd5bafca1cbed71cfb86648aad75d1c264663c0cc78c7aea8daec0d', + 'url': 'https://huggingface.co/CompVis/stable-diffusion-v-1-2-original/resolve/main/sd-v1-2.ckpt', + 'requires_login': True, + }, + "sd-v1-1-full-ema.ckpt": { + 'sha256': 'efdeb5dc418a025d9a8cc0a8617e106c69044bc2925abecc8a254b2910d69829', + 'url':'https://huggingface.co/CompVis/stable-diffusion-v-1-1-original/resolve/main/sd-v1-1-full-ema.ckpt', + 'requires_login': True, + }, + "sd-v1-1.ckpt": { + 'sha256': '86cd1d3ccb044d7ba8db743d717c9bac603c4043508ad2571383f954390f3cea', + 'url': 'https://huggingface.co/CompVis/stable-diffusion-v-1-1-original/resolve/main/sd-v1-1.ckpt', + 'requires_login': True, + }, + "robo-diffusion-v1.ckpt": { + 'sha256': '244dbe0dcb55c761bde9c2ac0e9b46cc9705ebfe5f1f3a7cc46251573ea14e16', + 'url': 'https://huggingface.co/nousr/robo-diffusion/resolve/main/models/robo-diffusion-v1.ckpt', + 'requires_login': False, + }, + "wd-v1-3-float16.ckpt": { + 'sha256': '4afab9126057859b34d13d6207d90221d0b017b7580469ea70cee37757a29edd', + 'url': 'https://huggingface.co/hakurei/waifu-diffusion-v1-3/resolve/main/wd-v1-3-float16.ckpt', + 'requires_login': False, + }, + } + + # config path + ckpt_config_path = root.custom_config_path if root.model_config == "custom" else os.path.join(root.configs_path, root.model_config) + + if os.path.exists(ckpt_config_path): + print(f"{ckpt_config_path} exists") + else: + print(f"Warning: {ckpt_config_path} does not exist.") + ckpt_config_path = os.path.join(path_extend,"configs",root.model_config) + print(f"Using {ckpt_config_path} instead.") + + ckpt_config_path = os.path.abspath(ckpt_config_path) + + # checkpoint path or download + ckpt_path = root.custom_checkpoint_path if root.model_checkpoint == "custom" else os.path.join(root.models_path, root.model_checkpoint) + ckpt_valid = True + + if os.path.exists(ckpt_path): + pass + elif 'url' in model_map[root.model_checkpoint]: + url = model_map[root.model_checkpoint]['url'] + + # CLI dialogue to authenticate download + if model_map[root.model_checkpoint]['requires_login']: + print("This model requires an authentication token") + print("Please ensure you have accepted the terms of service before continuing.") + + username = input("[What is your huggingface username?]: ") + token = input("[What is your huggingface token?]: ") + + _, path = url.split("https://") + + url = f"https://{username}:{token}@{path}" + + # contact server for model + print(f"..attempting to download {root.model_checkpoint}...this may take a while") + ckpt_request = requests.get(url) + request_status = ckpt_request.status_code + + # inform user of errors + if request_status == 403: + raise ConnectionRefusedError("You have not accepted the license for this model.") + elif request_status == 404: + raise ConnectionError("Could not make contact with server") + elif request_status != 200: + raise ConnectionError(f"Some other error has ocurred - response code: {request_status}") + + # write to model path + with open(os.path.join(root.models_path, root.model_checkpoint), 'wb') as model_file: + model_file.write(ckpt_request.content) + else: + print(f"Please download model checkpoint and place in {os.path.join(root.models_path, root.model_checkpoint)}") + ckpt_valid = False + + print(f"config_path: {ckpt_config_path}") + print(f"ckpt_path: {ckpt_path}") + + if check_sha256 and root.model_checkpoint != "custom" and ckpt_valid: + try: + import hashlib + print("..checking sha256") + with open(ckpt_path, "rb") as f: + bytes = f.read() + hash = hashlib.sha256(bytes).hexdigest() + del bytes + if model_map[root.model_checkpoint]["sha256"] == hash: + print("..hash is correct") + else: + print("..hash in not correct") + ckpt_valid = False + except: + print("..could not verify model integrity") + + def load_model_from_config(config, ckpt, verbose=False, device='cuda', half_precision=True,print_flag=False): + map_location = "cuda" # ["cpu", "cuda"] + print(f"..loading model") + pl_sd = torch.load(ckpt, map_location=map_location) + if "global_step" in pl_sd: + if print_flag: + print(f"Global Step: {pl_sd['global_step']}") + sd = pl_sd["state_dict"] + model = instantiate_from_config(config.model) + m, u = model.load_state_dict(sd, strict=False) + if print_flag: + if len(m) > 0 and verbose: + print("missing keys:") + print(m) + if len(u) > 0 and verbose: + print("unexpected keys:") + print(u) + + if half_precision: + model = model.half().to(device) + else: + model = model.to(device) + model.eval() + return model + + if load_on_run_all and ckpt_valid: + local_config = OmegaConf.load(f"{ckpt_config_path}") + model = load_model_from_config(local_config, f"{ckpt_path}", half_precision=root.half_precision) + device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + model = model.to(device) + + autoencoder_version = "sd-v1" #TODO this will be different for different models + model.linear_decode = make_linear_decode(autoencoder_version, device) + + return model, device + + +def get_model_output_paths(root): + + models_path = root.models_path + output_path = root.output_path + + #@markdown **Google Drive Path Variables (Optional)** + + force_remount = False + + try: + ipy = get_ipython() + except: + ipy = 'could not get_ipython' + + if 'google.colab' in str(ipy): + if root.mount_google_drive: + from google.colab import drive # type: ignore + try: + drive_path = "/content/drive" + drive.mount(drive_path,force_remount=force_remount) + models_path = root.models_path_gdrive + output_path = root.output_path_gdrive + except: + print("..error mounting drive or with drive path variables") + print("..reverting to default path variables") + + models_path = os.path.abspath(models_path) + output_path = os.path.abspath(output_path) + os.makedirs(models_path, exist_ok=True) + os.makedirs(output_path, exist_ok=True) + + print(f"models_path: {models_path}") + print(f"output_path: {output_path}") + + return models_path, output_path diff --git a/deforum-stable-diffusion/helpers/model_wrap.py b/deforum-stable-diffusion/helpers/model_wrap.py new file mode 100644 index 0000000000000000000000000000000000000000..fb3b37ae88effb0cec5a38c09a78ab7ae89b1b44 --- /dev/null +++ b/deforum-stable-diffusion/helpers/model_wrap.py @@ -0,0 +1,226 @@ +from torch import nn +from k_diffusion import utils as k_utils +import torch +from k_diffusion.external import CompVisDenoiser +from torchvision.utils import make_grid +from IPython import display +from torchvision.transforms.functional import to_pil_image + +class CFGDenoiser(nn.Module): + def __init__(self, model): + super().__init__() + self.inner_model = model + + def forward(self, x, sigma, uncond, cond, cond_scale): + x_in = torch.cat([x] * 2) + sigma_in = torch.cat([sigma] * 2) + cond_in = torch.cat([uncond, cond]) + uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2) + return uncond + (cond - uncond) * cond_scale + +class CFGDenoiserWithGrad(CompVisDenoiser): + def __init__(self, model, + loss_fns_scales, # List of [cond_function, scale] pairs + clamp_func=None, # Gradient clamping function, clamp_func(grad, sigma) + gradient_wrt=None, # Calculate gradient with respect to ["x", "x0_pred", "both"] + gradient_add_to=None, # Add gradient to ["cond", "uncond", "both"] + cond_uncond_sync=True, # Calculates the cond and uncond simultaneously + decode_method=None, # Function used to decode the latent during gradient calculation + grad_inject_timing_fn=None, # Option to use grad in only a few of the steps + grad_consolidate_fn=None, # Function to add grad to image fn(img, grad, sigma) + verbose=False): + super().__init__(model.inner_model) + self.inner_model = model + self.cond_uncond_sync = cond_uncond_sync + + # Initialize gradient calculation variables + self.clamp_func = clamp_func + self.gradient_add_to = gradient_add_to + if gradient_wrt is None: + self.gradient_wrt = 'x' + self.gradient_wrt = gradient_wrt + if decode_method is None: + decode_fn = lambda x: x + elif decode_method == "autoencoder": + decode_fn = model.inner_model.differentiable_decode_first_stage + elif decode_method == "linear": + decode_fn = model.inner_model.linear_decode + self.decode_fn = decode_fn + + # Parse loss function-scale pairs + cond_fns = [] + for loss_fn,scale in loss_fns_scales: + if scale != 0: + cond_fn = self.make_cond_fn(loss_fn, scale) + else: + cond_fn = None + cond_fns += [cond_fn] + self.cond_fns = cond_fns + + if grad_inject_timing_fn is None: + self.grad_inject_timing_fn = lambda sigma: True + else: + self.grad_inject_timing_fn = grad_inject_timing_fn + if grad_consolidate_fn is None: + self.grad_consolidate_fn = lambda img, grad, sigma: img + grad * sigma + else: + self.grad_consolidate_fn = grad_consolidate_fn + + self.verbose = verbose + self.verbose_print = print if self.verbose else lambda *args, **kwargs: None + + + # General denoising model with gradient conditioning + def cond_model_fn_(self, x, sigma, inner_model=None, **kwargs): + + # inner_model: optionally use a different inner_model function or a wrapper function around inner_model, see self.forward._cfg_model + if inner_model is None: + inner_model = self.inner_model + + total_cond_grad = torch.zeros_like(x) + for cond_fn in self.cond_fns: + if cond_fn is None: continue + + # Gradient with respect to x + if self.gradient_wrt == 'x': + with torch.enable_grad(): + x = x.detach().requires_grad_() + denoised = inner_model(x, sigma, **kwargs) + cond_grad = cond_fn(x, sigma, denoised=denoised, **kwargs).detach() + + # Gradient wrt x0_pred, so save some compute: don't record grad until after denoised is calculated + elif self.gradient_wrt == 'x0_pred': + with torch.no_grad(): + denoised = inner_model(x, sigma, **kwargs) + with torch.enable_grad(): + cond_grad = cond_fn(x, sigma, denoised=denoised.detach().requires_grad_(), **kwargs).detach() + total_cond_grad += cond_grad + + total_cond_grad = torch.nan_to_num(total_cond_grad, nan=0.0, posinf=float('inf'), neginf=-float('inf')) + + # Clamp the gradient + total_cond_grad = self.clamp_grad_verbose(total_cond_grad, sigma) + + # Add gradient to the image + if self.gradient_wrt == 'x': + x.copy_(self.grad_consolidate_fn(x.detach(), total_cond_grad, k_utils.append_dims(sigma, x.ndim))) + cond_denoised = inner_model(x, sigma, **kwargs) + elif self.gradient_wrt == 'x0_pred': + x.copy_(self.grad_consolidate_fn(x.detach(), total_cond_grad, k_utils.append_dims(sigma, x.ndim))) + cond_denoised = self.grad_consolidate_fn(denoised.detach(), total_cond_grad, k_utils.append_dims(sigma, x.ndim)) + + return cond_denoised + + def forward(self, x, sigma, uncond, cond, cond_scale): + + def _cfg_model(x, sigma, cond, **kwargs): + # Wrapper to add denoised cond and uncond as in a cfg model + # input "cond" is both cond and uncond weights: torch.cat([uncond, cond]) + x_in = torch.cat([x] * 2) + sigma_in = torch.cat([sigma] * 2) + + denoised = self.inner_model(x_in, sigma_in, cond=cond, **kwargs) + uncond_x0, cond_x0 = denoised.chunk(2) + x0_pred = uncond_x0 + (cond_x0 - uncond_x0) * cond_scale + return x0_pred + + # Conditioning + if self.check_conditioning_schedule(sigma): + # Apply the conditioning gradient to the completed denoised (after both cond and uncond are combined into the diffused image) + if self.cond_uncond_sync: + # x0 = self.cfg_cond_model_fn_(x, sigma, uncond=uncond, cond=cond, cond_scale=cond_scale) + cond_in = torch.cat([uncond, cond]) + x0 = self.cond_model_fn_(x, sigma, cond=cond_in, inner_model=_cfg_model) + + # Calculate cond and uncond separately + else: + if self.gradient_add_to == "uncond": + uncond = self.cond_model_fn_(x, sigma, cond=uncond) + cond = self.inner_model(x, sigma, cond=cond) + x0 = uncond + (cond - uncond) * cond_scale + elif self.gradient_add_to == "cond": + uncond = self.inner_model(x, sigma, cond=uncond) + cond = self.cond_model_fn_(x, sigma, cond=cond) + x0 = uncond + (cond - uncond) * cond_scale + elif self.gradient_add_to == "both": + uncond = self.cond_model_fn_(x, sigma, cond=uncond) + cond = self.cond_model_fn_(x, sigma, cond=cond) + x0 = uncond + (cond - uncond) * cond_scale + else: + raise Exception(f"Unrecognised option for gradient_add_to: {self.gradient_add_to}") + + # No conditioning + else: + # calculate cond and uncond simultaneously + if self.cond_uncond_sync: + cond_in = torch.cat([uncond, cond]) + x0 = _cfg_model(x, sigma, cond=cond_in) + else: + uncond = self.inner_model(x, sigma, cond=uncond) + cond = self.inner_model(x, sigma, cond=cond) + x0 = uncond + (cond - uncond) * cond_scale + + return x0 + + def make_cond_fn(self, loss_fn, scale): + # Turns a loss function into a cond function that is applied to the decoded RGB sample + # loss_fn (function): func(x, sigma, denoised) -> number + # scale (number): how much this loss is applied to the image + + # Cond function with respect to x + def cond_fn(x, sigma, denoised, **kwargs): + with torch.enable_grad(): + denoised_sample = self.decode_fn(denoised).requires_grad_() + loss = loss_fn(denoised_sample, sigma, **kwargs) * scale + grad = -torch.autograd.grad(loss, x)[0] + self.verbose_print('Loss:', loss.item()) + return grad + + # Cond function with respect to x0_pred + def cond_fn_pred(x, sigma, denoised, **kwargs): + with torch.enable_grad(): + denoised_sample = self.decode_fn(denoised).requires_grad_() + loss = loss_fn(denoised_sample, sigma, **kwargs) * scale + grad = -torch.autograd.grad(loss, denoised)[0] + self.verbose_print('Loss:', loss.item()) + return grad + + if self.gradient_wrt == 'x': + return cond_fn + elif self.gradient_wrt == 'x0_pred': + return cond_fn_pred + else: + raise Exception(f"Variable gradient_wrt == {self.gradient_wrt} not recognised.") + + def clamp_grad_verbose(self, grad, sigma): + if self.clamp_func is not None: + if self.verbose: + print("Grad before clamping:") + self.display_samples(torch.abs(grad*2.0) - 1.0) + grad = self.clamp_func(grad, sigma) + if self.verbose: + print("Conditioning gradient") + self.display_samples(torch.abs(grad*2.0) - 1.0) + return grad + + def check_conditioning_schedule(self, sigma): + is_conditioning_step = False + + if (self.cond_fns is not None and + any(cond_fn is not None for cond_fn in self.cond_fns)): + # Conditioning strength != 0 + # Check if this is a conditioning step + if self.grad_inject_timing_fn(sigma): + is_conditioning_step = True + + if self.verbose: + print(f"Conditioning step for sigma={sigma}") + + return is_conditioning_step + + def display_samples(self, images): + images = images.double().cpu().add(1).div(2).clamp(0, 1) + images = torch.tensor(images.numpy()) + grid = make_grid(images, 4).cpu() + display.display(to_pil_image(grid)) + return diff --git a/deforum-stable-diffusion/helpers/prompt.py b/deforum-stable-diffusion/helpers/prompt.py new file mode 100644 index 0000000000000000000000000000000000000000..bfb78fdbb029cc21f0b45603a6eba22e411c528a --- /dev/null +++ b/deforum-stable-diffusion/helpers/prompt.py @@ -0,0 +1,130 @@ +import re + +def sanitize(prompt): + whitelist = set('abcdefghijklmnopqrstuvwxyz ABCDEFGHIJKLMNOPQRSTUVWXYZ') + tmp = ''.join(filter(whitelist.__contains__, prompt)) + return tmp.replace(' ', '_') + +def check_is_number(value): + float_pattern = r'^(?=.)([+-]?([0-9]*)(\.([0-9]+))?)$' + return re.match(float_pattern, value) + +# prompt weighting with colons and number coefficients (like 'bacon:0.75 eggs:0.25') +# borrowed from https://github.com/kylewlacy/stable-diffusion/blob/0a4397094eb6e875f98f9d71193e350d859c4220/ldm/dream/conditioning.py +# and https://github.com/raefu/stable-diffusion-automatic/blob/unstablediffusion/modules/processing.py +def get_uc_and_c(prompts, model, args, frame = 0): + prompt = prompts[0] # they are the same in a batch anyway + + # get weighted sub-prompts + negative_subprompts, positive_subprompts = split_weighted_subprompts( + prompt, frame, not args.normalize_prompt_weights + ) + + uc = get_learned_conditioning(model, negative_subprompts, "", args, -1) + c = get_learned_conditioning(model, positive_subprompts, prompt, args, 1) + + return (uc, c) + +def get_learned_conditioning(model, weighted_subprompts, text, args, sign = 1): + if len(weighted_subprompts) < 1: + log_tokenization(text, model, args.log_weighted_subprompts, sign) + c = model.get_learned_conditioning(args.n_samples * [text]) + else: + c = None + for subtext, subweight in weighted_subprompts: + log_tokenization(subtext, model, args.log_weighted_subprompts, sign * subweight) + if c is None: + c = model.get_learned_conditioning(args.n_samples * [subtext]) + c *= subweight + else: + c.add_(model.get_learned_conditioning(args.n_samples * [subtext]), alpha=subweight) + + return c + +def parse_weight(match, frame = 0)->float: + import numexpr + w_raw = match.group("weight") + if w_raw == None: + return 1 + if check_is_number(w_raw): + return float(w_raw) + else: + t = frame + if len(w_raw) < 3: + print('the value inside `-characters cannot represent a math function') + return 1 + return float(numexpr.evaluate(w_raw[1:-1])) + +def normalize_prompt_weights(parsed_prompts): + if len(parsed_prompts) == 0: + return parsed_prompts + weight_sum = sum(map(lambda x: x[1], parsed_prompts)) + if weight_sum == 0: + print( + "Warning: Subprompt weights add up to zero. Discarding and using even weights instead.") + equal_weight = 1 / max(len(parsed_prompts), 1) + return [(x[0], equal_weight) for x in parsed_prompts] + return [(x[0], x[1] / weight_sum) for x in parsed_prompts] + +def split_weighted_subprompts(text, frame = 0, skip_normalize=False): + """ + grabs all text up to the first occurrence of ':' + uses the grabbed text as a sub-prompt, and takes the value following ':' as weight + if ':' has no value defined, defaults to 1.0 + repeats until no text remaining + """ + prompt_parser = re.compile(""" + (?P # capture group for 'prompt' + (?:\\\:|[^:])+ # match one or more non ':' characters or escaped colons '\:' + ) # end 'prompt' + (?: # non-capture group + :+ # match one or more ':' characters + (?P(( # capture group for 'weight' + -?\d+(?:\.\d+)? # match positive or negative integer or decimal number + )|( # or + `[\S\s]*?`# a math function + )))? # end weight capture group, make optional + \s* # strip spaces after weight + | # OR + $ # else, if no ':' then match end of line + ) # end non-capture group + """, re.VERBOSE) + negative_prompts = [] + positive_prompts = [] + for match in re.finditer(prompt_parser, text): + w = parse_weight(match, frame) + if w < 0: + # negating the sign as we'll feed this to uc + negative_prompts.append((match.group("prompt").replace("\\:", ":"), -w)) + elif w > 0: + positive_prompts.append((match.group("prompt").replace("\\:", ":"), w)) + + if skip_normalize: + return (negative_prompts, positive_prompts) + return (normalize_prompt_weights(negative_prompts), normalize_prompt_weights(positive_prompts)) + +# shows how the prompt is tokenized +# usually tokens have '' to indicate end-of-word, +# but for readability it has been replaced with ' ' +def log_tokenization(text, model, log=False, weight=1): + if not log: + return + tokens = model.cond_stage_model.tokenizer._tokenize(text) + tokenized = "" + discarded = "" + usedTokens = 0 + totalTokens = len(tokens) + for i in range(0, totalTokens): + token = tokens[i].replace('', ' ') + # alternate color + s = (usedTokens % 6) + 1 + if i < model.cond_stage_model.max_length: + tokenized = tokenized + f"\x1b[0;3{s};40m{token}" + usedTokens += 1 + else: # over max token length + discarded = discarded + f"\x1b[0;3{s};40m{token}" + print(f"\n>> Tokens ({usedTokens}), Weight ({weight:.2f}):\n{tokenized}\x1b[0m") + if discarded != "": + print( + f">> Tokens Discarded ({totalTokens-usedTokens}):\n{discarded}\x1b[0m" + ) \ No newline at end of file diff --git a/deforum-stable-diffusion/helpers/rank_images.py b/deforum-stable-diffusion/helpers/rank_images.py new file mode 100644 index 0000000000000000000000000000000000000000..35a1994e5051d52c0df8244e74f0d644093b60c4 --- /dev/null +++ b/deforum-stable-diffusion/helpers/rank_images.py @@ -0,0 +1,69 @@ +import os +from argparse import ArgumentParser +from tqdm import tqdm +from PIL import Image +from torch.nn import functional as F +from torchvision import transforms +from torchvision.transforms import functional as TF +import torch +from simulacra_fit_linear_model import AestheticMeanPredictionLinearModel +from CLIP import clip + +parser = ArgumentParser() +parser.add_argument("directory") +parser.add_argument("-t", "--top-n", default=50) +args = parser.parse_args() + +device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') + +clip_model_name = 'ViT-B/16' +clip_model = clip.load(clip_model_name, jit=False, device=device)[0] +clip_model.eval().requires_grad_(False) + +normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], + std=[0.26862954, 0.26130258, 0.27577711]) + +# 512 is embed dimension for ViT-B/16 CLIP +model = AestheticMeanPredictionLinearModel(512) +model.load_state_dict( + torch.load("models/sac_public_2022_06_29_vit_b_16_linear.pth") +) +model = model.to(device) + +def get_filepaths(parentpath, filepaths): + paths = [] + for path in filepaths: + try: + new_parent = os.path.join(parentpath, path) + paths += get_filepaths(new_parent, os.listdir(new_parent)) + except NotADirectoryError: + paths.append(os.path.join(parentpath, path)) + return paths + +filepaths = get_filepaths(args.directory, os.listdir(args.directory)) +scores = [] +for path in tqdm(filepaths): + # This is obviously a flawed way to check for an image but this is just + # a demo script anyway. + if path[-4:] not in (".png", ".jpg"): + continue + img = Image.open(path).convert('RGB') + img = TF.resize(img, 224, transforms.InterpolationMode.LANCZOS) + img = TF.center_crop(img, (224,224)) + img = TF.to_tensor(img).to(device) + img = normalize(img) + clip_image_embed = F.normalize( + clip_model.encode_image(img[None, ...]).float(), + dim=-1) + score = model(clip_image_embed) + if len(scores) < args.top_n: + scores.append((score.item(),path)) + scores.sort() + else: + if scores[0][0] < score: + scores.append((score.item(),path)) + scores.sort(key=lambda x: x[0]) + scores = scores[1:] + +for score, path in scores: + print(f"{score}: {path}") diff --git a/deforum-stable-diffusion/helpers/render.py b/deforum-stable-diffusion/helpers/render.py new file mode 100644 index 0000000000000000000000000000000000000000..c2961ffae42fa9e3ffe84d952fa5c776b45c657d --- /dev/null +++ b/deforum-stable-diffusion/helpers/render.py @@ -0,0 +1,472 @@ +import os +import json +from IPython import display +import random +from torchvision.utils import make_grid +from einops import rearrange +import pandas as pd +import cv2 +import numpy as np +from PIL import Image +import pathlib +import torchvision.transforms as T + +from .generate import generate, add_noise +from .prompt import sanitize +from .animation import DeformAnimKeys, sample_from_cv2, sample_to_cv2, anim_frame_warp, vid2frames +from .depth import DepthModel +from .colors import maintain_colors +from .load_images import prepare_overlay_mask + +def next_seed(args): + if args.seed_behavior == 'iter': + args.seed += 1 + elif args.seed_behavior == 'fixed': + pass # always keep seed the same + else: + args.seed = random.randint(0, 2**32 - 1) + return args.seed + +def render_image_batch(args, prompts, root): + args.prompts = {k: f"{v:05d}" for v, k in enumerate(prompts)} + + # create output folder for the batch + os.makedirs(args.outdir, exist_ok=True) + if args.save_settings or args.save_samples: + print(f"Saving to {os.path.join(args.outdir, args.timestring)}_*") + + # save settings for the batch + if args.save_settings: + filename = os.path.join(args.outdir, f"{args.timestring}_settings.txt") + with open(filename, "w+", encoding="utf-8") as f: + dictlist = dict(args.__dict__) + del dictlist['master_args'] + del dictlist['root'] + del dictlist['get_output_folder'] + json.dump(dictlist, f, ensure_ascii=False, indent=4) + + index = 0 + + # function for init image batching + init_array = [] + if args.use_init: + if args.init_image == "": + raise FileNotFoundError("No path was given for init_image") + if args.init_image.startswith('http://') or args.init_image.startswith('https://'): + init_array.append(args.init_image) + elif not os.path.isfile(args.init_image): + if args.init_image[-1] != "/": # avoids path error by adding / to end if not there + args.init_image += "/" + for image in sorted(os.listdir(args.init_image)): # iterates dir and appends images to init_array + if image.split(".")[-1] in ("png", "jpg", "jpeg"): + init_array.append(args.init_image + image) + else: + init_array.append(args.init_image) + else: + init_array = [""] + + # when doing large batches don't flood browser with images + clear_between_batches = args.n_batch >= 32 + + for iprompt, prompt in enumerate(prompts): + args.prompt = prompt + args.clip_prompt = prompt + print(f"Prompt {iprompt+1} of {len(prompts)}") + print(f"{args.prompt}") + + all_images = [] + + for batch_index in range(args.n_batch): + if clear_between_batches and batch_index % 32 == 0: + display.clear_output(wait=True) + print(f"Batch {batch_index+1} of {args.n_batch}") + + for image in init_array: # iterates the init images + args.init_image = image + results = generate(args, root) + for image in results: + if args.make_grid: + all_images.append(T.functional.pil_to_tensor(image)) + if args.save_samples: + if args.filename_format == "{timestring}_{index}_{prompt}.png": + filename = f"{args.timestring}_{index:05}_{sanitize(prompt)[:160]}.png" + else: + filename = f"{args.timestring}_{index:05}_{args.seed}.png" + image.save(os.path.join(args.outdir, filename)) + if args.display_samples: + display.display(image) + index += 1 + args.seed = next_seed(args) + + #print(len(all_images)) + if args.make_grid: + grid = make_grid(all_images, nrow=int(len(all_images)/args.grid_rows)) + grid = rearrange(grid, 'c h w -> h w c').cpu().numpy() + filename = f"{args.timestring}_{iprompt:05d}_grid_{args.seed}.png" + grid_image = Image.fromarray(grid.astype(np.uint8)) + grid_image.save(os.path.join(args.outdir, filename)) + display.clear_output(wait=True) + display.display(grid_image) + + +def render_animation(args, anim_args, animation_prompts, root): + # animations use key framed prompts + args.prompts = animation_prompts + + # expand key frame strings to values + keys = DeformAnimKeys(anim_args) + + # resume animation + start_frame = 0 + if anim_args.resume_from_timestring: + for tmp in os.listdir(args.outdir): + if tmp.split("_")[0] == anim_args.resume_timestring: + start_frame += 1 + start_frame = start_frame - 1 + + # create output folder for the batch + os.makedirs(args.outdir, exist_ok=True) + print(f"Saving animation frames to {args.outdir}") + + # save settings for the batch + ''' + settings_filename = os.path.join(args.outdir, f"{args.timestring}_settings.txt") + with open(settings_filename, "w+", encoding="utf-8") as f: + s = {**dict(args.__dict__), **dict(anim_args.__dict__)} + #DGSpitzer: run.py adds these three parameters + del s['master_args'] + del s['opt'] + del s['root'] + del s['get_output_folder'] + #print(s) + json.dump(s, f, ensure_ascii=False, indent=4) + ''' + # resume from timestring + if anim_args.resume_from_timestring: + args.timestring = anim_args.resume_timestring + + # expand prompts out to per-frame + prompt_series = pd.Series([np.nan for a in range(anim_args.max_frames)]) + for i, prompt in animation_prompts.items(): + prompt_series[int(i)] = prompt + prompt_series = prompt_series.ffill().bfill() + + # check for video inits + using_vid_init = anim_args.animation_mode == 'Video Input' + + # load depth model for 3D + predict_depths = (anim_args.animation_mode == '3D' and anim_args.use_depth_warping) or anim_args.save_depth_maps + if predict_depths: + depth_model = DepthModel(root.device) + depth_model.load_midas(root.models_path) + if anim_args.midas_weight < 1.0: + depth_model.load_adabins(root.models_path) + else: + depth_model = None + anim_args.save_depth_maps = False + + # state for interpolating between diffusion steps + turbo_steps = 1 if using_vid_init else int(anim_args.diffusion_cadence) + turbo_prev_image, turbo_prev_frame_idx = None, 0 + turbo_next_image, turbo_next_frame_idx = None, 0 + + # resume animation + prev_sample = None + color_match_sample = None + if anim_args.resume_from_timestring: + last_frame = start_frame-1 + if turbo_steps > 1: + last_frame -= last_frame%turbo_steps + path = os.path.join(args.outdir,f"{args.timestring}_{last_frame:05}.png") + img = cv2.imread(path) + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + prev_sample = sample_from_cv2(img) + if anim_args.color_coherence != 'None': + color_match_sample = img + if turbo_steps > 1: + turbo_next_image, turbo_next_frame_idx = sample_to_cv2(prev_sample, type=np.float32), last_frame + turbo_prev_image, turbo_prev_frame_idx = turbo_next_image, turbo_next_frame_idx + start_frame = last_frame+turbo_steps + + args.n_samples = 1 + frame_idx = start_frame + while frame_idx < anim_args.max_frames: + print(f"Rendering animation frame {frame_idx} of {anim_args.max_frames}") + noise = keys.noise_schedule_series[frame_idx] + strength = keys.strength_schedule_series[frame_idx] + contrast = keys.contrast_schedule_series[frame_idx] + depth = None + + # emit in-between frames + if turbo_steps > 1: + tween_frame_start_idx = max(0, frame_idx-turbo_steps) + for tween_frame_idx in range(tween_frame_start_idx, frame_idx): + tween = float(tween_frame_idx - tween_frame_start_idx + 1) / float(frame_idx - tween_frame_start_idx) + print(f" creating in between frame {tween_frame_idx} tween:{tween:0.2f}") + + advance_prev = turbo_prev_image is not None and tween_frame_idx > turbo_prev_frame_idx + advance_next = tween_frame_idx > turbo_next_frame_idx + + if depth_model is not None: + assert(turbo_next_image is not None) + depth = depth_model.predict(turbo_next_image, anim_args) + + if advance_prev: + turbo_prev_image, _ = anim_frame_warp(turbo_prev_image, args, anim_args, keys, tween_frame_idx, depth_model, depth=depth, device=root.device) + if advance_next: + turbo_next_image, _ = anim_frame_warp(turbo_next_image, args, anim_args, keys, tween_frame_idx, depth_model, depth=depth, device=root.device) + # Transformed raw image before color coherence and noise. Used for mask overlay + if args.use_mask and args.overlay_mask: + # Apply transforms to the original image + init_image_raw, _ = anim_frame_warp(args.init_sample_raw, args, anim_args, keys, frame_idx, depth_model, depth, device=root.device) + if root.half_precision: + args.init_sample_raw = sample_from_cv2(init_image_raw).half().to(root.device) + else: + args.init_sample_raw = sample_from_cv2(init_image_raw).to(root.device) + + #Transform the mask image + if args.use_mask: + if args.mask_sample is None: + args.mask_sample = prepare_overlay_mask(args, root, prev_sample.shape) + # Transform the mask + mask_image, _ = anim_frame_warp(args.mask_sample, args, anim_args, keys, frame_idx, depth_model, depth, device=root.device) + if root.half_precision: + args.mask_sample = sample_from_cv2(mask_image).half().to(root.device) + else: + args.mask_sample = sample_from_cv2(mask_image).to(root.device) + + turbo_prev_frame_idx = turbo_next_frame_idx = tween_frame_idx + + if turbo_prev_image is not None and tween < 1.0: + img = turbo_prev_image*(1.0-tween) + turbo_next_image*tween + else: + img = turbo_next_image + + filename = f"{args.timestring}_{tween_frame_idx:05}.png" + cv2.imwrite(os.path.join(args.outdir, filename), cv2.cvtColor(img.astype(np.uint8), cv2.COLOR_RGB2BGR)) + if anim_args.save_depth_maps: + depth_model.save(os.path.join(args.outdir, f"{args.timestring}_depth_{tween_frame_idx:05}.png"), depth) + if turbo_next_image is not None: + prev_sample = sample_from_cv2(turbo_next_image) + + # apply transforms to previous frame + if prev_sample is not None: + prev_img, depth = anim_frame_warp(prev_sample, args, anim_args, keys, frame_idx, depth_model, depth=None, device=root.device) + + # Transformed raw image before color coherence and noise. Used for mask overlay + if args.use_mask and args.overlay_mask: + # Apply transforms to the original image + init_image_raw, _ = anim_frame_warp(args.init_sample_raw, args, anim_args, keys, frame_idx, depth_model, depth, device=root.device) + + if root.half_precision: + args.init_sample_raw = sample_from_cv2(init_image_raw).half().to(root.device) + else: + args.init_sample_raw = sample_from_cv2(init_image_raw).to(root.device) + + #Transform the mask image + if args.use_mask: + if args.mask_sample is None: + args.mask_sample = prepare_overlay_mask(args, root, prev_sample.shape) + # Transform the mask + mask_sample, _ = anim_frame_warp(args.mask_sample, args, anim_args, keys, frame_idx, depth_model, depth, device=root.device) + + if root.half_precision: + args.mask_sample = sample_from_cv2(mask_sample).half().to(root.device) + else: + args.mask_sample = sample_from_cv2(mask_sample).to(root.device) + + # apply color matching + if anim_args.color_coherence != 'None': + if color_match_sample is None: + color_match_sample = prev_img.copy() + else: + prev_img = maintain_colors(prev_img, color_match_sample, anim_args.color_coherence) + + # apply scaling + contrast_sample = prev_img * contrast + # apply frame noising + noised_sample = add_noise(sample_from_cv2(contrast_sample), noise) + + # use transformed previous frame as init for current + args.use_init = True + if root.half_precision: + args.init_sample = noised_sample.half().to(root.device) + else: + args.init_sample = noised_sample.to(root.device) + args.strength = max(0.0, min(1.0, strength)) + + # grab prompt for current frame + args.prompt = prompt_series[frame_idx] + args.clip_prompt = args.prompt + print(f"{args.prompt} {args.seed}") + if not using_vid_init: + print(f"Angle: {keys.angle_series[frame_idx]} Zoom: {keys.zoom_series[frame_idx]}") + print(f"Tx: {keys.translation_x_series[frame_idx]} Ty: {keys.translation_y_series[frame_idx]} Tz: {keys.translation_z_series[frame_idx]}") + print(f"Rx: {keys.rotation_3d_x_series[frame_idx]} Ry: {keys.rotation_3d_y_series[frame_idx]} Rz: {keys.rotation_3d_z_series[frame_idx]}") + + # grab init image for current frame + if using_vid_init: + init_frame = os.path.join(args.outdir, 'inputframes', f"{frame_idx+1:05}.jpg") + print(f"Using video init frame {init_frame}") + args.init_image = init_frame + if anim_args.use_mask_video: + mask_frame = os.path.join(args.outdir, 'maskframes', f"{frame_idx+1:05}.jpg") + args.mask_file = mask_frame + + # sample the diffusion model + sample, image = generate(args, root, frame_idx, return_latent=False, return_sample=True) + # First image sample used for masking + if not using_vid_init: + prev_sample = sample + if args.use_mask and args.overlay_mask: + if args.init_sample_raw is None: + args.init_sample_raw = sample + + if turbo_steps > 1: + turbo_prev_image, turbo_prev_frame_idx = turbo_next_image, turbo_next_frame_idx + turbo_next_image, turbo_next_frame_idx = sample_to_cv2(sample, type=np.float32), frame_idx + frame_idx += turbo_steps + else: + filename = f"{args.timestring}_{frame_idx:05}.png" + image.save(os.path.join(args.outdir, filename)) + if anim_args.save_depth_maps: + depth = depth_model.predict(sample_to_cv2(sample), anim_args) + depth_model.save(os.path.join(args.outdir, f"{args.timestring}_depth_{frame_idx:05}.png"), depth) + frame_idx += 1 + + display.clear_output(wait=True) + display.display(image) + + args.seed = next_seed(args) + +def render_input_video(args, anim_args, animation_prompts, root): + # create a folder for the video input frames to live in + video_in_frame_path = os.path.join(args.outdir, 'inputframes') + os.makedirs(video_in_frame_path, exist_ok=True) + + # save the video frames from input video + print(f"Exporting Video Frames (1 every {anim_args.extract_nth_frame}) frames to {video_in_frame_path}...") + vid2frames(anim_args.video_init_path, video_in_frame_path, anim_args.extract_nth_frame, anim_args.overwrite_extracted_frames) + + # determine max frames from length of input frames + anim_args.max_frames = len([f for f in pathlib.Path(video_in_frame_path).glob('*.jpg')]) + args.use_init = True + print(f"Loading {anim_args.max_frames} input frames from {video_in_frame_path} and saving video frames to {args.outdir}") + + if anim_args.use_mask_video: + # create a folder for the mask video input frames to live in + mask_in_frame_path = os.path.join(args.outdir, 'maskframes') + os.makedirs(mask_in_frame_path, exist_ok=True) + + # save the video frames from mask video + print(f"Exporting Video Frames (1 every {anim_args.extract_nth_frame}) frames to {mask_in_frame_path}...") + vid2frames(anim_args.video_mask_path, mask_in_frame_path, anim_args.extract_nth_frame, anim_args.overwrite_extracted_frames) + args.use_mask = True + args.overlay_mask = True + + render_animation(args, anim_args, animation_prompts, root) + +def render_interpolation(args, anim_args, animation_prompts, root): + # animations use key framed prompts + args.prompts = animation_prompts + + # create output folder for the batch + os.makedirs(args.outdir, exist_ok=True) + print(f"Saving animation frames to {args.outdir}") + + # save settings for the batch + settings_filename = os.path.join(args.outdir, f"{args.timestring}_settings.txt") + with open(settings_filename, "w+", encoding="utf-8") as f: + s = {**dict(args.__dict__), **dict(anim_args.__dict__)} + del s['master_args'] + del s['opt'] + del s['root'] + del s['get_output_folder'] + json.dump(s, f, ensure_ascii=False, indent=4) + + # Interpolation Settings + args.n_samples = 1 + args.seed_behavior = 'fixed' # force fix seed at the moment bc only 1 seed is available + prompts_c_s = [] # cache all the text embeddings + + print(f"Preparing for interpolation of the following...") + + for i, prompt in animation_prompts.items(): + args.prompt = prompt + args.clip_prompt = args.prompt + + # sample the diffusion model + results = generate(args, root, return_c=True) + c, image = results[0], results[1] + prompts_c_s.append(c) + + # display.clear_output(wait=True) + display.display(image) + + args.seed = next_seed(args) + + display.clear_output(wait=True) + print(f"Interpolation start...") + + frame_idx = 0 + + if anim_args.interpolate_key_frames: + for i in range(len(prompts_c_s)-1): + dist_frames = list(animation_prompts.items())[i+1][0] - list(animation_prompts.items())[i][0] + if dist_frames <= 0: + print("key frames duplicated or reversed. interpolation skipped.") + return + else: + for j in range(dist_frames): + # interpolate the text embedding + prompt1_c = prompts_c_s[i] + prompt2_c = prompts_c_s[i+1] + args.init_c = prompt1_c.add(prompt2_c.sub(prompt1_c).mul(j * 1/dist_frames)) + + # sample the diffusion model + results = generate(args, root) + image = results[0] + + filename = f"{args.timestring}_{frame_idx:05}.png" + image.save(os.path.join(args.outdir, filename)) + frame_idx += 1 + + display.clear_output(wait=True) + display.display(image) + + args.seed = next_seed(args) + + else: + for i in range(len(prompts_c_s)-1): + for j in range(anim_args.interpolate_x_frames+1): + # interpolate the text embedding + prompt1_c = prompts_c_s[i] + prompt2_c = prompts_c_s[i+1] + args.init_c = prompt1_c.add(prompt2_c.sub(prompt1_c).mul(j * 1/(anim_args.interpolate_x_frames+1))) + + # sample the diffusion model + results = generate(args, root) + image = results[0] + + filename = f"{args.timestring}_{frame_idx:05}.png" + image.save(os.path.join(args.outdir, filename)) + frame_idx += 1 + + display.clear_output(wait=True) + display.display(image) + + args.seed = next_seed(args) + + # generate the last prompt + args.init_c = prompts_c_s[-1] + results = generate(args, root) + image = results[0] + filename = f"{args.timestring}_{frame_idx:05}.png" + image.save(os.path.join(args.outdir, filename)) + + display.clear_output(wait=True) + display.display(image) + args.seed = next_seed(args) + + #clear init_c + args.init_c = None diff --git a/deforum-stable-diffusion/helpers/save_images.py b/deforum-stable-diffusion/helpers/save_images.py new file mode 100644 index 0000000000000000000000000000000000000000..c96c12f50184c67e6744a4060004cd5bb9e88071 --- /dev/null +++ b/deforum-stable-diffusion/helpers/save_images.py @@ -0,0 +1,60 @@ +from typing import List, Tuple +from einops import rearrange +import numpy as np, os, torch +from PIL import Image +from torchvision.utils import make_grid +import time + + +def get_output_folder(output_path, batch_folder): + out_path = os.path.join(output_path,time.strftime('%Y-%m')) + if batch_folder != "": + out_path = os.path.join(out_path, batch_folder) + os.makedirs(out_path, exist_ok=True) + return out_path + + +def save_samples( + args, x_samples: torch.Tensor, seed: int, n_rows: int +) -> Tuple[Image.Image, List[Image.Image]]: + """Function to save samples to disk. + Args: + args: Stable deforum diffusion arguments. + x_samples: Samples to save. + seed: Seed for the experiment. + n_rows: Number of rows in the grid. + Returns: + A tuple of the grid image and a list of the generated images. + ( grid_image, generated_images ) + """ + + # save samples + images = [] + grid_image = None + if args.display_samples or args.save_samples: + for index, x_sample in enumerate(x_samples): + x_sample = 255.0 * rearrange(x_sample.cpu().numpy(), "c h w -> h w c") + images.append(Image.fromarray(x_sample.astype(np.uint8))) + if args.save_samples: + images[-1].save( + os.path.join( + args.outdir, f"{args.timestring}_{index:02}_{seed}.png" + ) + ) + + # save grid + if args.display_grid or args.save_grid: + grid = torch.stack([x_samples], 0) + grid = rearrange(grid, "n b c h w -> (n b) c h w") + grid = make_grid(grid, nrow=n_rows, padding=0) + + # to image + grid = 255.0 * rearrange(grid, "c h w -> h w c").cpu().numpy() + grid_image = Image.fromarray(grid.astype(np.uint8)) + if args.save_grid: + grid_image.save( + os.path.join(args.outdir, f"{args.timestring}_{seed}_grid.png") + ) + + # return grid_image and individual sample images + return grid_image, images diff --git a/deforum-stable-diffusion/helpers/settings.py b/deforum-stable-diffusion/helpers/settings.py new file mode 100644 index 0000000000000000000000000000000000000000..39abf235551ab93135b8d74d022d281e90403f9d --- /dev/null +++ b/deforum-stable-diffusion/helpers/settings.py @@ -0,0 +1,34 @@ +import os +import json + +def load_args(args_dict, anim_args_dict, settings_file, custom_settings_file, verbose=True): + default_settings_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir, 'settings')) + if settings_file.lower() == 'custom': + settings_filename = custom_settings_file + else: + settings_filename = os.path.join(default_settings_dir,settings_file) + print(f"Reading custom settings from {settings_filename}...") + if not os.path.isfile(settings_filename): + print('The settings file does not exist. The in-notebook settings will be used instead.') + else: + if not verbose: + print(f"Any settings not included in {settings_filename} will use the in-notebook settings by default.") + with open(settings_filename, "r") as f: + jdata = json.loads(f.read()) + if jdata.get("prompts") is not None: + animation_prompts = jdata["prompts"] + for i, k in enumerate(args_dict): + if k in jdata: + args_dict[k] = jdata[k] + else: + if verbose: + print(f"key {k} doesn't exist in the custom settings data! using the default value of {args_dict[k]}") + for i, k in enumerate(anim_args_dict): + if k in jdata: + anim_args_dict[k] = jdata[k] + else: + if verbose: + print(f"key {k} doesn't exist in the custom settings data! using the default value of {anim_args_dict[k]}") + if verbose: + print(args_dict) + print(anim_args_dict) \ No newline at end of file diff --git a/deforum-stable-diffusion/helpers/simulacra_compute_embeddings.py b/deforum-stable-diffusion/helpers/simulacra_compute_embeddings.py new file mode 100644 index 0000000000000000000000000000000000000000..1fd3bfc1db73542a77ca1f9be438877f314fd1e6 --- /dev/null +++ b/deforum-stable-diffusion/helpers/simulacra_compute_embeddings.py @@ -0,0 +1,96 @@ +#!/usr/bin/env python3 + +"""Precomputes CLIP embeddings for Simulacra Aesthetic Captions.""" + +import argparse +import os +from pathlib import Path +import sqlite3 + +from PIL import Image + +import torch +from torch import multiprocessing as mp +from torch.utils import data +import torchvision.transforms as transforms +from tqdm import tqdm + +from CLIP import clip + + +class SimulacraDataset(data.Dataset): + """Simulacra dataset + Args: + images_dir: directory + transform: preprocessing and augmentation of the training images + """ + + def __init__(self, images_dir, db, transform=None): + self.images_dir = Path(images_dir) + self.transform = transform + self.conn = sqlite3.connect(db) + self.ratings = [] + for row in self.conn.execute('SELECT generations.id, images.idx, paths.path, AVG(ratings.rating) FROM images JOIN generations ON images.gid=generations.id JOIN ratings ON images.id=ratings.iid JOIN paths ON images.id=paths.iid GROUP BY images.id'): + self.ratings.append(row) + + def __len__(self): + return len(self.ratings) + + def __getitem__(self, key): + gid, idx, filename, rating = self.ratings[key] + image = Image.open(self.images_dir / filename).convert('RGB') + if self.transform: + image = self.transform(image) + return image, torch.tensor(rating) + + +def main(): + p = argparse.ArgumentParser(description=__doc__) + p.add_argument('--batch-size', '-bs', type=int, default=10, + help='the CLIP model') + p.add_argument('--clip-model', type=str, default='ViT-B/16', + help='the CLIP model') + p.add_argument('--db', type=str, required=True, + help='the database location') + p.add_argument('--device', type=str, + help='the device to use') + p.add_argument('--images-dir', type=str, required=True, + help='the dataset images directory') + p.add_argument('--num-workers', type=int, default=8, + help='the number of data loader workers') + p.add_argument('--output', type=str, required=True, + help='the output file') + p.add_argument('--start-method', type=str, default='spawn', + choices=['fork', 'forkserver', 'spawn'], + help='the multiprocessing start method') + args = p.parse_args() + + mp.set_start_method(args.start_method) + if args.device: + device = torch.device(device) + else: + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + print('Using device:', device) + + clip_model, clip_tf = clip.load(args.clip_model, device=device, jit=False) + clip_model = clip_model.eval().requires_grad_(False) + + dataset = SimulacraDataset(args.images_dir, args.db, transform=clip_tf) + loader = data.DataLoader(dataset, args.batch_size, num_workers=args.num_workers) + + embeds, ratings = [], [] + + for batch in tqdm(loader): + images_batch, ratings_batch = batch + embeds.append(clip_model.encode_image(images_batch.to(device)).cpu()) + ratings.append(ratings_batch.clone()) + + obj = {'clip_model': args.clip_model, + 'embeds': torch.cat(embeds), + 'ratings': torch.cat(ratings)} + + torch.save(obj, args.output) + + +if __name__ == '__main__': + main() diff --git a/deforum-stable-diffusion/helpers/simulacra_fit_linear_model.py b/deforum-stable-diffusion/helpers/simulacra_fit_linear_model.py new file mode 100644 index 0000000000000000000000000000000000000000..0a80e77f406ca068fd2912040585f2086e7b5436 --- /dev/null +++ b/deforum-stable-diffusion/helpers/simulacra_fit_linear_model.py @@ -0,0 +1,52 @@ +#!/usr/bin/env python3 + +"""Fits a linear aesthetic model to precomputed CLIP embeddings.""" + +import argparse + +import numpy as np +from sklearn.linear_model import Ridge +from sklearn.model_selection import train_test_split +import torch +from torch import nn +from torch.nn import functional as F + + +class AestheticMeanPredictionLinearModel(nn.Module): + def __init__(self, feats_in): + super().__init__() + self.linear = nn.Linear(feats_in, 1) + + def forward(self, input): + x = F.normalize(input, dim=-1) * input.shape[-1] ** 0.5 + return self.linear(x) + + +def main(): + p = argparse.ArgumentParser(description=__doc__) + p.add_argument('input', type=str, help='the input feature vectors') + p.add_argument('output', type=str, help='the output model') + p.add_argument('--val-size', type=float, default=0.1, help='the validation set size') + p.add_argument('--seed', type=int, default=0, help='the random seed') + args = p.parse_args() + + train_set = torch.load(args.input, map_location='cpu') + X = F.normalize(train_set['embeds'].float(), dim=-1).numpy() + X *= X.shape[-1] ** 0.5 + y = train_set['ratings'].numpy() + X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=args.val_size, random_state=args.seed) + regression = Ridge() + regression.fit(X_train, y_train) + score_train = regression.score(X_train, y_train) + score_val = regression.score(X_val, y_val) + print(f'Score on train: {score_train:g}') + print(f'Score on val: {score_val:g}') + model = AestheticMeanPredictionLinearModel(X_train.shape[1]) + with torch.no_grad(): + model.linear.weight.copy_(torch.tensor(regression.coef_)) + model.linear.bias.copy_(torch.tensor(regression.intercept_)) + torch.save(model.state_dict(), args.output) + + +if __name__ == '__main__': + main() diff --git a/deforum-stable-diffusion/helpers/video.py b/deforum-stable-diffusion/helpers/video.py new file mode 100644 index 0000000000000000000000000000000000000000..11171719b84fc33fa32b564a0c0433b75734734f --- /dev/null +++ b/deforum-stable-diffusion/helpers/video.py @@ -0,0 +1,25 @@ +import os +import pathlib + +def vid2frames(video_path, frames_path, n=1, overwrite=True): + if not os.path.exists(frames_path) or overwrite: + try: + for f in pathlib.Path(video_in_frame_path).glob('*.jpg'): + f.unlink() + except: + pass + assert os.path.exists(video_path), f"Video input {video_path} does not exist" + + vidcap = cv2.VideoCapture(video_path) + success,image = vidcap.read() + count = 0 + t=1 + success = True + while success: + if count % n == 0: + cv2.imwrite(frames_path + os.path.sep + f"{t:05}.jpg" , image) # save frame as JPEG file + t += 1 + success,image = vidcap.read() + count += 1 + print("Converted %d frames" % count) + else: print("Frames already unpacked") \ No newline at end of file diff --git a/deforum-stable-diffusion/models/sac_public_2022_06_29_vit_b_16_linear.pth b/deforum-stable-diffusion/models/sac_public_2022_06_29_vit_b_16_linear.pth new file mode 100644 index 0000000000000000000000000000000000000000..31d28bda414d5cb8f10b15ea4ee909ba54ca52a4 --- /dev/null +++ b/deforum-stable-diffusion/models/sac_public_2022_06_29_vit_b_16_linear.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cdc27b2196565bcb69f2692731084b0189612b98aae74adec033782469ab583c +size 3111 diff --git a/deforum-stable-diffusion/models/sac_public_2022_06_29_vit_b_32_linear.pth b/deforum-stable-diffusion/models/sac_public_2022_06_29_vit_b_32_linear.pth new file mode 100644 index 0000000000000000000000000000000000000000..3d1f4e4c6426f3e6ae40f5f8483d42c477620c10 --- /dev/null +++ b/deforum-stable-diffusion/models/sac_public_2022_06_29_vit_b_32_linear.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:64315370e2293a08cc2f4816ae9db715f935791569c1a57d84b943e482ea5e12 +size 3111 diff --git a/deforum-stable-diffusion/models/sac_public_2022_06_29_vit_l_14_linear.pth b/deforum-stable-diffusion/models/sac_public_2022_06_29_vit_l_14_linear.pth new file mode 100644 index 0000000000000000000000000000000000000000..babe2d0db7544b1993d5ece121a1066ef4ab87da --- /dev/null +++ b/deforum-stable-diffusion/models/sac_public_2022_06_29_vit_l_14_linear.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f77d1d0f4bb04b0ff8e60c5747fd0bfff0cc4a85263a6106627522c06eae85c6 +size 4135 diff --git a/deforum-stable-diffusion/requirements.txt b/deforum-stable-diffusion/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..a6051e45364fcc5cf7ee74756e9c75401e843b21 --- /dev/null +++ b/deforum-stable-diffusion/requirements.txt @@ -0,0 +1,28 @@ +clean-fid +colab-convert +einops +ffmpeg +ftfy +ipython +ipywidgets +jsonmerge +jupyterlab +jupyter_http_over_ws +kornia +matplotlib +notebook +numexpr +omegaconf +opencv-python +pandas +pytorch_lightning==1.7.7 +resize-right +scikit-image +scikit-learn +timm +torchdiffeq +transformers==4.19.2 +albumentations +more_itertools +devtools +validators diff --git a/deforum-stable-diffusion/src/adabins/__init__.py b/deforum-stable-diffusion/src/adabins/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8b2a0eea190658f294d0a49363ea28543087bdf6 --- /dev/null +++ b/deforum-stable-diffusion/src/adabins/__init__.py @@ -0,0 +1 @@ +from .unet_adaptive_bins import UnetAdaptiveBins diff --git a/deforum-stable-diffusion/src/adabins/__pycache__/__init__.cpython-38.pyc b/deforum-stable-diffusion/src/adabins/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6f48238b5500e489aefd9f043df316e0ceaa3351 Binary files /dev/null and b/deforum-stable-diffusion/src/adabins/__pycache__/__init__.cpython-38.pyc differ diff --git a/deforum-stable-diffusion/src/adabins/__pycache__/layers.cpython-38.pyc b/deforum-stable-diffusion/src/adabins/__pycache__/layers.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bc3cd7e4c41c77145e23caa63b0acb923d627260 Binary files /dev/null and b/deforum-stable-diffusion/src/adabins/__pycache__/layers.cpython-38.pyc differ diff --git a/deforum-stable-diffusion/src/adabins/__pycache__/miniViT.cpython-38.pyc b/deforum-stable-diffusion/src/adabins/__pycache__/miniViT.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6289cbdaea9e5331aa652dd353f0fb4ec213d038 Binary files /dev/null and b/deforum-stable-diffusion/src/adabins/__pycache__/miniViT.cpython-38.pyc differ diff --git a/deforum-stable-diffusion/src/adabins/__pycache__/unet_adaptive_bins.cpython-38.pyc b/deforum-stable-diffusion/src/adabins/__pycache__/unet_adaptive_bins.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..210b7f072047defa5e3f40e9cb416527f8bf5a10 Binary files /dev/null and b/deforum-stable-diffusion/src/adabins/__pycache__/unet_adaptive_bins.cpython-38.pyc differ diff --git a/deforum-stable-diffusion/src/adabins/layers.py b/deforum-stable-diffusion/src/adabins/layers.py new file mode 100644 index 0000000000000000000000000000000000000000..499cd8cc1ec5973da5718d184d36b187869f9c28 --- /dev/null +++ b/deforum-stable-diffusion/src/adabins/layers.py @@ -0,0 +1,36 @@ +import torch +import torch.nn as nn + + +class PatchTransformerEncoder(nn.Module): + def __init__(self, in_channels, patch_size=10, embedding_dim=128, num_heads=4): + super(PatchTransformerEncoder, self).__init__() + encoder_layers = nn.TransformerEncoderLayer(embedding_dim, num_heads, dim_feedforward=1024) + self.transformer_encoder = nn.TransformerEncoder(encoder_layers, num_layers=4) # takes shape S,N,E + + self.embedding_convPxP = nn.Conv2d(in_channels, embedding_dim, + kernel_size=patch_size, stride=patch_size, padding=0) + + self.positional_encodings = nn.Parameter(torch.rand(500, embedding_dim), requires_grad=True) + + def forward(self, x): + embeddings = self.embedding_convPxP(x).flatten(2) # .shape = n,c,s = n, embedding_dim, s + # embeddings = nn.functional.pad(embeddings, (1,0)) # extra special token at start ? + embeddings = embeddings + self.positional_encodings[:embeddings.shape[2], :].T.unsqueeze(0) + + # change to S,N,E format required by transformer + embeddings = embeddings.permute(2, 0, 1) + x = self.transformer_encoder(embeddings) # .shape = S, N, E + return x + + +class PixelWiseDotProduct(nn.Module): + def __init__(self): + super(PixelWiseDotProduct, self).__init__() + + def forward(self, x, K): + n, c, h, w = x.size() + _, cout, ck = K.size() + assert c == ck, "Number of channels in x and Embedding dimension (at dim 2) of K matrix must match" + y = torch.matmul(x.view(n, c, h * w).permute(0, 2, 1), K.permute(0, 2, 1)) # .shape = n, hw, cout + return y.permute(0, 2, 1).view(n, cout, h, w) diff --git a/deforum-stable-diffusion/src/adabins/miniViT.py b/deforum-stable-diffusion/src/adabins/miniViT.py new file mode 100644 index 0000000000000000000000000000000000000000..8a619734aaa82e73fbe37800a6a1dd12e83020a2 --- /dev/null +++ b/deforum-stable-diffusion/src/adabins/miniViT.py @@ -0,0 +1,45 @@ +import torch +import torch.nn as nn + +from .layers import PatchTransformerEncoder, PixelWiseDotProduct + + +class mViT(nn.Module): + def __init__(self, in_channels, n_query_channels=128, patch_size=16, dim_out=256, + embedding_dim=128, num_heads=4, norm='linear'): + super(mViT, self).__init__() + self.norm = norm + self.n_query_channels = n_query_channels + self.patch_transformer = PatchTransformerEncoder(in_channels, patch_size, embedding_dim, num_heads) + self.dot_product_layer = PixelWiseDotProduct() + + self.conv3x3 = nn.Conv2d(in_channels, embedding_dim, kernel_size=3, stride=1, padding=1) + self.regressor = nn.Sequential(nn.Linear(embedding_dim, 256), + nn.LeakyReLU(), + nn.Linear(256, 256), + nn.LeakyReLU(), + nn.Linear(256, dim_out)) + + def forward(self, x): + # n, c, h, w = x.size() + tgt = self.patch_transformer(x.clone()) # .shape = S, N, E + + x = self.conv3x3(x) + + regression_head, queries = tgt[0, ...], tgt[1:self.n_query_channels + 1, ...] + + # Change from S, N, E to N, S, E + queries = queries.permute(1, 0, 2) + range_attention_maps = self.dot_product_layer(x, queries) # .shape = n, n_query_channels, h, w + + y = self.regressor(regression_head) # .shape = N, dim_out + if self.norm == 'linear': + y = torch.relu(y) + eps = 0.1 + y = y + eps + elif self.norm == 'softmax': + return torch.softmax(y, dim=1), range_attention_maps + else: + y = torch.sigmoid(y) + y = y / y.sum(dim=1, keepdim=True) + return y, range_attention_maps diff --git a/deforum-stable-diffusion/src/adabins/unet_adaptive_bins.py b/deforum-stable-diffusion/src/adabins/unet_adaptive_bins.py new file mode 100644 index 0000000000000000000000000000000000000000..ff06c9b8635dff6edce757795151094a41ddba66 --- /dev/null +++ b/deforum-stable-diffusion/src/adabins/unet_adaptive_bins.py @@ -0,0 +1,146 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from geffnet import tf_efficientnet_b5_ap +from .miniViT import mViT + + +class UpSampleBN(nn.Module): + def __init__(self, skip_input, output_features): + super(UpSampleBN, self).__init__() + + self._net = nn.Sequential(nn.Conv2d(skip_input, output_features, kernel_size=3, stride=1, padding=1), + nn.BatchNorm2d(output_features), + nn.LeakyReLU(), + nn.Conv2d(output_features, output_features, kernel_size=3, stride=1, padding=1), + nn.BatchNorm2d(output_features), + nn.LeakyReLU()) + + def forward(self, x, concat_with): + up_x = F.interpolate(x, size=[concat_with.size(2), concat_with.size(3)], mode='bilinear', align_corners=True) + f = torch.cat([up_x, concat_with], dim=1) + return self._net(f) + + +class DecoderBN(nn.Module): + def __init__(self, num_features=2048, num_classes=1, bottleneck_features=2048): + super(DecoderBN, self).__init__() + features = int(num_features) + + self.conv2 = nn.Conv2d(bottleneck_features, features, kernel_size=1, stride=1, padding=1) + + self.up1 = UpSampleBN(skip_input=features // 1 + 112 + 64, output_features=features // 2) + self.up2 = UpSampleBN(skip_input=features // 2 + 40 + 24, output_features=features // 4) + self.up3 = UpSampleBN(skip_input=features // 4 + 24 + 16, output_features=features // 8) + self.up4 = UpSampleBN(skip_input=features // 8 + 16 + 8, output_features=features // 16) + + # self.up5 = UpSample(skip_input=features // 16 + 3, output_features=features//16) + self.conv3 = nn.Conv2d(features // 16, num_classes, kernel_size=3, stride=1, padding=1) + # self.act_out = nn.Softmax(dim=1) if output_activation == 'softmax' else nn.Identity() + + def forward(self, features): + x_block0, x_block1, x_block2, x_block3, x_block4 = features[4], features[5], features[6], features[8], features[ + 11] + + x_d0 = self.conv2(x_block4) + + x_d1 = self.up1(x_d0, x_block3) + x_d2 = self.up2(x_d1, x_block2) + x_d3 = self.up3(x_d2, x_block1) + x_d4 = self.up4(x_d3, x_block0) + # x_d5 = self.up5(x_d4, features[0]) + out = self.conv3(x_d4) + # out = self.act_out(out) + # if with_features: + # return out, features[-1] + # elif with_intermediate: + # return out, [x_block0, x_block1, x_block2, x_block3, x_block4, x_d1, x_d2, x_d3, x_d4] + return out + + +class Encoder(nn.Module): + def __init__(self, backend): + super(Encoder, self).__init__() + self.original_model = backend + + def forward(self, x): + features = [x] + for k, v in self.original_model._modules.items(): + if (k == 'blocks'): + for ki, vi in v._modules.items(): + features.append(vi(features[-1])) + else: + features.append(v(features[-1])) + return features + + +class UnetAdaptiveBins(nn.Module): + def __init__(self, backend, n_bins=100, min_val=0.1, max_val=10, norm='linear'): + super(UnetAdaptiveBins, self).__init__() + self.num_classes = n_bins + self.min_val = min_val + self.max_val = max_val + self.encoder = Encoder(backend) + self.adaptive_bins_layer = mViT(128, n_query_channels=128, patch_size=16, + dim_out=n_bins, + embedding_dim=128, norm=norm) + + self.decoder = DecoderBN(num_classes=128) + self.conv_out = nn.Sequential(nn.Conv2d(128, n_bins, kernel_size=1, stride=1, padding=0), + nn.Softmax(dim=1)) + + def forward(self, x, **kwargs): + unet_out = self.decoder(self.encoder(x), **kwargs) + bin_widths_normed, range_attention_maps = self.adaptive_bins_layer(unet_out) + out = self.conv_out(range_attention_maps) + + # Post process + # n, c, h, w = out.shape + # hist = torch.sum(out.view(n, c, h * w), dim=2) / (h * w) # not used for training + + bin_widths = (self.max_val - self.min_val) * bin_widths_normed # .shape = N, dim_out + bin_widths = nn.functional.pad(bin_widths, (1, 0), mode='constant', value=self.min_val) + bin_edges = torch.cumsum(bin_widths, dim=1) + + centers = 0.5 * (bin_edges[:, :-1] + bin_edges[:, 1:]) + n, dout = centers.size() + centers = centers.view(n, dout, 1, 1) + + pred = torch.sum(out * centers, dim=1, keepdim=True) + + return bin_edges, pred + + def get_1x_lr_params(self): # lr/10 learning rate + return self.encoder.parameters() + + def get_10x_lr_params(self): # lr learning rate + modules = [self.decoder, self.adaptive_bins_layer, self.conv_out] + for m in modules: + yield from m.parameters() + + @classmethod + def build(cls, n_bins, **kwargs): + basemodel_name = 'tf_efficientnet_b5_ap' + + print('Loading base model ()...'.format(basemodel_name), end='') + basemodel = tf_efficientnet_b5_ap(pretrained=False) + # basemodel = torch.hub.load('rwightman/gen-efficientnet-pytorch', basemodel_name, pretrained=True) + print('Done.') + + # Remove last layer + print('Removing last two layers (global_pool & classifier).') + basemodel.global_pool = nn.Identity() + basemodel.classifier = nn.Identity() + + # Building Encoder-Decoder model + print('Building Encoder-Decoder model..', end='') + m = cls(basemodel, n_bins=n_bins, **kwargs) + print('Done.') + return m + + +if __name__ == '__main__': + model = UnetAdaptiveBins.build(100) + x = torch.rand(2, 3, 480, 640) + bins, pred = model(x) + print(bins.shape, pred.shape) diff --git a/deforum-stable-diffusion/src/clip/__init__.py b/deforum-stable-diffusion/src/clip/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..dcc5619538c0f7c782508bdbd9587259d805e0d9 --- /dev/null +++ b/deforum-stable-diffusion/src/clip/__init__.py @@ -0,0 +1 @@ +from .clip import * diff --git a/deforum-stable-diffusion/src/clip/bpe_simple_vocab_16e6.txt.gz b/deforum-stable-diffusion/src/clip/bpe_simple_vocab_16e6.txt.gz new file mode 100644 index 0000000000000000000000000000000000000000..36a15856e00a06a9fbed8cdd34d2393fea4a3113 --- /dev/null +++ b/deforum-stable-diffusion/src/clip/bpe_simple_vocab_16e6.txt.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a +size 1356917 diff --git a/deforum-stable-diffusion/src/clip/clip.py b/deforum-stable-diffusion/src/clip/clip.py new file mode 100644 index 0000000000000000000000000000000000000000..257511e1d40c120e0d64a0f1562d44b2b8a40a17 --- /dev/null +++ b/deforum-stable-diffusion/src/clip/clip.py @@ -0,0 +1,237 @@ +import hashlib +import os +import urllib +import warnings +from typing import Any, Union, List +from pkg_resources import packaging + +import torch +from PIL import Image +from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize +from tqdm import tqdm + +from .model import build_model +from .simple_tokenizer import SimpleTokenizer as _Tokenizer + +try: + from torchvision.transforms import InterpolationMode + BICUBIC = InterpolationMode.BICUBIC +except ImportError: + BICUBIC = Image.BICUBIC + + +if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"): + warnings.warn("PyTorch version 1.7.1 or higher is recommended") + + +__all__ = ["available_models", "load", "tokenize"] +_tokenizer = _Tokenizer() + +_MODELS = { + "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", + "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", + "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", + "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", + "RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt", + "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", + "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", + "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt", + "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt", +} + + +def _download(url: str, root: str): + os.makedirs(root, exist_ok=True) + filename = os.path.basename(url) + + expected_sha256 = url.split("/")[-2] + download_target = os.path.join(root, filename) + + if os.path.exists(download_target) and not os.path.isfile(download_target): + raise RuntimeError(f"{download_target} exists and is not a regular file") + + if os.path.isfile(download_target): + if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256: + return download_target + else: + warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") + + with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: + with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop: + while True: + buffer = source.read(8192) + if not buffer: + break + + output.write(buffer) + loop.update(len(buffer)) + + if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256: + raise RuntimeError("Model has been downloaded but the SHA256 checksum does not not match") + + return download_target + + +def _convert_image_to_rgb(image): + return image.convert("RGB") + + +def _transform(n_px): + return Compose([ + Resize(n_px, interpolation=BICUBIC), + CenterCrop(n_px), + _convert_image_to_rgb, + ToTensor(), + Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), + ]) + + +def available_models() -> List[str]: + """Returns the names of available CLIP models""" + return list(_MODELS.keys()) + + +def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None): + """Load a CLIP model + + Parameters + ---------- + name : str + A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict + + device : Union[str, torch.device] + The device to put the loaded model + + jit : bool + Whether to load the optimized JIT model or more hackable non-JIT model (default). + + download_root: str + path to download the model files; by default, it uses "~/.cache/clip" + + Returns + ------- + model : torch.nn.Module + The CLIP model + + preprocess : Callable[[PIL.Image], torch.Tensor] + A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input + """ + if name in _MODELS: + model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip")) + elif os.path.isfile(name): + model_path = name + else: + raise RuntimeError(f"Model {name} not found; available models = {available_models()}") + + with open(model_path, 'rb') as opened_file: + try: + # loading JIT archive + model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval() + state_dict = None + except RuntimeError: + # loading saved state dict + if jit: + warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead") + jit = False + state_dict = torch.load(opened_file, map_location="cpu") + + if not jit: + model = build_model(state_dict or model.state_dict()).to(device) + if str(device) == "cpu": + model.float() + return model, _transform(model.visual.input_resolution) + + # patch the device names + device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) + device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] + + def patch_device(module): + try: + graphs = [module.graph] if hasattr(module, "graph") else [] + except RuntimeError: + graphs = [] + + if hasattr(module, "forward1"): + graphs.append(module.forward1.graph) + + for graph in graphs: + for node in graph.findAllNodes("prim::Constant"): + if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): + node.copyAttributes(device_node) + + model.apply(patch_device) + patch_device(model.encode_image) + patch_device(model.encode_text) + + # patch dtype to float32 on CPU + if str(device) == "cpu": + float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) + float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] + float_node = float_input.node() + + def patch_float(module): + try: + graphs = [module.graph] if hasattr(module, "graph") else [] + except RuntimeError: + graphs = [] + + if hasattr(module, "forward1"): + graphs.append(module.forward1.graph) + + for graph in graphs: + for node in graph.findAllNodes("aten::to"): + inputs = list(node.inputs()) + for i in [1, 2]: # dtype can be the second or third argument to aten::to() + if inputs[i].node()["value"] == 5: + inputs[i].node().copyAttributes(float_node) + + model.apply(patch_float) + patch_float(model.encode_image) + patch_float(model.encode_text) + + model.float() + + return model, _transform(model.input_resolution.item()) + + +def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> Union[torch.IntTensor, torch.LongTensor]: + """ + Returns the tokenized representation of given input string(s) + + Parameters + ---------- + texts : Union[str, List[str]] + An input string or a list of input strings to tokenize + + context_length : int + The context length to use; all CLIP models use 77 as the context length + + truncate: bool + Whether to truncate the text in case its encoding is longer than the context length + + Returns + ------- + A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]. + We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long. + """ + if isinstance(texts, str): + texts = [texts] + + sot_token = _tokenizer.encoder["<|startoftext|>"] + eot_token = _tokenizer.encoder["<|endoftext|>"] + all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] + if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"): + result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) + else: + result = torch.zeros(len(all_tokens), context_length, dtype=torch.int) + + for i, tokens in enumerate(all_tokens): + if len(tokens) > context_length: + if truncate: + tokens = tokens[:context_length] + tokens[-1] = eot_token + else: + raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") + result[i, :len(tokens)] = torch.tensor(tokens) + + return result diff --git a/deforum-stable-diffusion/src/clip/model.py b/deforum-stable-diffusion/src/clip/model.py new file mode 100644 index 0000000000000000000000000000000000000000..232b7792eb97440642547bd462cf128df9243933 --- /dev/null +++ b/deforum-stable-diffusion/src/clip/model.py @@ -0,0 +1,436 @@ +from collections import OrderedDict +from typing import Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1): + super().__init__() + + # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1 + self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.relu1 = nn.ReLU(inplace=True) + + self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.relu2 = nn.ReLU(inplace=True) + + self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() + + self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * self.expansion) + self.relu3 = nn.ReLU(inplace=True) + + self.downsample = None + self.stride = stride + + if stride > 1 or inplanes != planes * Bottleneck.expansion: + # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 + self.downsample = nn.Sequential(OrderedDict([ + ("-1", nn.AvgPool2d(stride)), + ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), + ("1", nn.BatchNorm2d(planes * self.expansion)) + ])) + + def forward(self, x: torch.Tensor): + identity = x + + out = self.relu1(self.bn1(self.conv1(x))) + out = self.relu2(self.bn2(self.conv2(out))) + out = self.avgpool(out) + out = self.bn3(self.conv3(out)) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu3(out) + return out + + +class AttentionPool2d(nn.Module): + def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): + super().__init__() + self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) + self.k_proj = nn.Linear(embed_dim, embed_dim) + self.q_proj = nn.Linear(embed_dim, embed_dim) + self.v_proj = nn.Linear(embed_dim, embed_dim) + self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) + self.num_heads = num_heads + + def forward(self, x): + x = x.flatten(start_dim=2).permute(2, 0, 1) # NCHW -> (HW)NC + x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC + x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC + x, _ = F.multi_head_attention_forward( + query=x[:1], key=x, value=x, + embed_dim_to_check=x.shape[-1], + num_heads=self.num_heads, + q_proj_weight=self.q_proj.weight, + k_proj_weight=self.k_proj.weight, + v_proj_weight=self.v_proj.weight, + in_proj_weight=None, + in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), + bias_k=None, + bias_v=None, + add_zero_attn=False, + dropout_p=0, + out_proj_weight=self.c_proj.weight, + out_proj_bias=self.c_proj.bias, + use_separate_proj_weight=True, + training=self.training, + need_weights=False + ) + return x.squeeze(0) + + +class ModifiedResNet(nn.Module): + """ + A ResNet class that is similar to torchvision's but contains the following changes: + - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. + - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 + - The final pooling layer is a QKV attention instead of an average pool + """ + + def __init__(self, layers, output_dim, heads, input_resolution=224, width=64): + super().__init__() + self.output_dim = output_dim + self.input_resolution = input_resolution + + # the 3-layer stem + self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(width // 2) + self.relu1 = nn.ReLU(inplace=True) + self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(width // 2) + self.relu2 = nn.ReLU(inplace=True) + self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) + self.bn3 = nn.BatchNorm2d(width) + self.relu3 = nn.ReLU(inplace=True) + self.avgpool = nn.AvgPool2d(2) + + # residual layers + self._inplanes = width # this is a *mutable* variable used during construction + self.layer1 = self._make_layer(width, layers[0]) + self.layer2 = self._make_layer(width * 2, layers[1], stride=2) + self.layer3 = self._make_layer(width * 4, layers[2], stride=2) + self.layer4 = self._make_layer(width * 8, layers[3], stride=2) + + embed_dim = width * 32 # the ResNet feature dimension + self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim) + + def _make_layer(self, planes, blocks, stride=1): + layers = [Bottleneck(self._inplanes, planes, stride)] + + self._inplanes = planes * Bottleneck.expansion + for _ in range(1, blocks): + layers.append(Bottleneck(self._inplanes, planes)) + + return nn.Sequential(*layers) + + def forward(self, x): + def stem(x): + x = self.relu1(self.bn1(self.conv1(x))) + x = self.relu2(self.bn2(self.conv2(x))) + x = self.relu3(self.bn3(self.conv3(x))) + x = self.avgpool(x) + return x + + x = x.type(self.conv1.weight.dtype) + x = stem(x) + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + x = self.attnpool(x) + + return x + + +class LayerNorm(nn.LayerNorm): + """Subclass torch's LayerNorm to handle fp16.""" + + def forward(self, x: torch.Tensor): + orig_type = x.dtype + ret = super().forward(x.type(torch.float32)) + return ret.type(orig_type) + + +class QuickGELU(nn.Module): + def forward(self, x: torch.Tensor): + return x * torch.sigmoid(1.702 * x) + + +class ResidualAttentionBlock(nn.Module): + def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None): + super().__init__() + + self.attn = nn.MultiheadAttention(d_model, n_head) + self.ln_1 = LayerNorm(d_model) + self.mlp = nn.Sequential(OrderedDict([ + ("c_fc", nn.Linear(d_model, d_model * 4)), + ("gelu", QuickGELU()), + ("c_proj", nn.Linear(d_model * 4, d_model)) + ])) + self.ln_2 = LayerNorm(d_model) + self.attn_mask = attn_mask + + def attention(self, x: torch.Tensor): + self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None + return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] + + def forward(self, x: torch.Tensor): + x = x + self.attention(self.ln_1(x)) + x = x + self.mlp(self.ln_2(x)) + return x + + +class Transformer(nn.Module): + def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None): + super().__init__() + self.width = width + self.layers = layers + self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]) + + def forward(self, x: torch.Tensor): + return self.resblocks(x) + + +class VisionTransformer(nn.Module): + def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int): + super().__init__() + self.input_resolution = input_resolution + self.output_dim = output_dim + self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) + + scale = width ** -0.5 + self.class_embedding = nn.Parameter(scale * torch.randn(width)) + self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width)) + self.ln_pre = LayerNorm(width) + + self.transformer = Transformer(width, layers, heads) + + self.ln_post = LayerNorm(width) + self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) + + def forward(self, x: torch.Tensor): + x = self.conv1(x) # shape = [*, width, grid, grid] + x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] + x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] + x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width] + x = x + self.positional_embedding.to(x.dtype) + x = self.ln_pre(x) + + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + + x = self.ln_post(x[:, 0, :]) + + if self.proj is not None: + x = x @ self.proj + + return x + + +class CLIP(nn.Module): + def __init__(self, + embed_dim: int, + # vision + image_resolution: int, + vision_layers: Union[Tuple[int, int, int, int], int], + vision_width: int, + vision_patch_size: int, + # text + context_length: int, + vocab_size: int, + transformer_width: int, + transformer_heads: int, + transformer_layers: int + ): + super().__init__() + + self.context_length = context_length + + if isinstance(vision_layers, (tuple, list)): + vision_heads = vision_width * 32 // 64 + self.visual = ModifiedResNet( + layers=vision_layers, + output_dim=embed_dim, + heads=vision_heads, + input_resolution=image_resolution, + width=vision_width + ) + else: + vision_heads = vision_width // 64 + self.visual = VisionTransformer( + input_resolution=image_resolution, + patch_size=vision_patch_size, + width=vision_width, + layers=vision_layers, + heads=vision_heads, + output_dim=embed_dim + ) + + self.transformer = Transformer( + width=transformer_width, + layers=transformer_layers, + heads=transformer_heads, + attn_mask=self.build_attention_mask() + ) + + self.vocab_size = vocab_size + self.token_embedding = nn.Embedding(vocab_size, transformer_width) + self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width)) + self.ln_final = LayerNorm(transformer_width) + + self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim)) + self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + + self.initialize_parameters() + + def initialize_parameters(self): + nn.init.normal_(self.token_embedding.weight, std=0.02) + nn.init.normal_(self.positional_embedding, std=0.01) + + if isinstance(self.visual, ModifiedResNet): + if self.visual.attnpool is not None: + std = self.visual.attnpool.c_proj.in_features ** -0.5 + nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std) + nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std) + nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std) + nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std) + + for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]: + for name, param in resnet_block.named_parameters(): + if name.endswith("bn3.weight"): + nn.init.zeros_(param) + + proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) + attn_std = self.transformer.width ** -0.5 + fc_std = (2 * self.transformer.width) ** -0.5 + for block in self.transformer.resblocks: + nn.init.normal_(block.attn.in_proj_weight, std=attn_std) + nn.init.normal_(block.attn.out_proj.weight, std=proj_std) + nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) + nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) + + if self.text_projection is not None: + nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) + + def build_attention_mask(self): + # lazily create causal attention mask, with full attention between the vision tokens + # pytorch uses additive attention mask; fill with -inf + mask = torch.empty(self.context_length, self.context_length) + mask.fill_(float("-inf")) + mask.triu_(1) # zero out the lower diagonal + return mask + + @property + def dtype(self): + return self.visual.conv1.weight.dtype + + def encode_image(self, image): + return self.visual(image.type(self.dtype)) + + def encode_text(self, text): + x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model] + + x = x + self.positional_embedding.type(self.dtype) + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + x = self.ln_final(x).type(self.dtype) + + # x.shape = [batch_size, n_ctx, transformer.width] + # take features from the eot embedding (eot_token is the highest number in each sequence) + x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection + + return x + + def forward(self, image, text): + image_features = self.encode_image(image) + text_features = self.encode_text(text) + + # normalized features + image_features = image_features / image_features.norm(dim=1, keepdim=True) + text_features = text_features / text_features.norm(dim=1, keepdim=True) + + # cosine similarity as logits + logit_scale = self.logit_scale.exp() + logits_per_image = logit_scale * image_features @ text_features.t() + logits_per_text = logits_per_image.t() + + # shape = [global_batch_size, global_batch_size] + return logits_per_image, logits_per_text + + +def convert_weights(model: nn.Module): + """Convert applicable model parameters to fp16""" + + def _convert_weights_to_fp16(l): + if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): + l.weight.data = l.weight.data.half() + if l.bias is not None: + l.bias.data = l.bias.data.half() + + if isinstance(l, nn.MultiheadAttention): + for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: + tensor = getattr(l, attr) + if tensor is not None: + tensor.data = tensor.data.half() + + for name in ["text_projection", "proj"]: + if hasattr(l, name): + attr = getattr(l, name) + if attr is not None: + attr.data = attr.data.half() + + model.apply(_convert_weights_to_fp16) + + +def build_model(state_dict: dict): + vit = "visual.proj" in state_dict + + if vit: + vision_width = state_dict["visual.conv1.weight"].shape[0] + vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]) + vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] + grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) + image_resolution = vision_patch_size * grid_size + else: + counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]] + vision_layers = tuple(counts) + vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] + output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5) + vision_patch_size = None + assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0] + image_resolution = output_width * 32 + + embed_dim = state_dict["text_projection"].shape[1] + context_length = state_dict["positional_embedding"].shape[0] + vocab_size = state_dict["token_embedding.weight"].shape[0] + transformer_width = state_dict["ln_final.weight"].shape[0] + transformer_heads = transformer_width // 64 + transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks"))) + + model = CLIP( + embed_dim, + image_resolution, vision_layers, vision_width, vision_patch_size, + context_length, vocab_size, transformer_width, transformer_heads, transformer_layers + ) + + for key in ["input_resolution", "context_length", "vocab_size"]: + if key in state_dict: + del state_dict[key] + + convert_weights(model) + model.load_state_dict(state_dict) + return model.eval() diff --git a/deforum-stable-diffusion/src/clip/simple_tokenizer.py b/deforum-stable-diffusion/src/clip/simple_tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..0a66286b7d5019c6e221932a813768038f839c91 --- /dev/null +++ b/deforum-stable-diffusion/src/clip/simple_tokenizer.py @@ -0,0 +1,132 @@ +import gzip +import html +import os +from functools import lru_cache + +import ftfy +import regex as re + + +@lru_cache() +def default_bpe(): + return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") + + +@lru_cache() +def bytes_to_unicode(): + """ + Returns list of utf-8 byte and a corresponding list of unicode strings. + The reversible bpe codes work on unicode strings. + This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. + When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. + This is a signficant percentage of your normal, say, 32K bpe vocab. + To avoid that, we want lookup tables between utf-8 bytes and unicode strings. + And avoids mapping to whitespace/control characters the bpe code barfs on. + """ + bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) + cs = bs[:] + n = 0 + for b in range(2**8): + if b not in bs: + bs.append(b) + cs.append(2**8+n) + n += 1 + cs = [chr(n) for n in cs] + return dict(zip(bs, cs)) + + +def get_pairs(word): + """Return set of symbol pairs in a word. + Word is represented as tuple of symbols (symbols being variable-length strings). + """ + pairs = set() + prev_char = word[0] + for char in word[1:]: + pairs.add((prev_char, char)) + prev_char = char + return pairs + + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r'\s+', ' ', text) + text = text.strip() + return text + + +class SimpleTokenizer(object): + def __init__(self, bpe_path: str = default_bpe()): + self.byte_encoder = bytes_to_unicode() + self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} + merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') + merges = merges[1:49152-256-2+1] + merges = [tuple(merge.split()) for merge in merges] + vocab = list(bytes_to_unicode().values()) + vocab = vocab + [v+'' for v in vocab] + for merge in merges: + vocab.append(''.join(merge)) + vocab.extend(['<|startoftext|>', '<|endoftext|>']) + self.encoder = dict(zip(vocab, range(len(vocab)))) + self.decoder = {v: k for k, v in self.encoder.items()} + self.bpe_ranks = dict(zip(merges, range(len(merges)))) + self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} + self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) + + def bpe(self, token): + if token in self.cache: + return self.cache[token] + word = tuple(token[:-1]) + ( token[-1] + '',) + pairs = get_pairs(word) + + if not pairs: + return token+'' + + while True: + bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) + if bigram not in self.bpe_ranks: + break + first, second = bigram + new_word = [] + i = 0 + while i < len(word): + try: + j = word.index(first, i) + new_word.extend(word[i:j]) + i = j + except: + new_word.extend(word[i:]) + break + + if word[i] == first and i < len(word)-1 and word[i+1] == second: + new_word.append(first+second) + i += 2 + else: + new_word.append(word[i]) + i += 1 + new_word = tuple(new_word) + word = new_word + if len(word) == 1: + break + else: + pairs = get_pairs(word) + word = ' '.join(word) + self.cache[token] = word + return word + + def encode(self, text): + bpe_tokens = [] + text = whitespace_clean(basic_clean(text)).lower() + for token in re.findall(self.pat, text): + token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) + bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) + return bpe_tokens + + def decode(self, tokens): + text = ''.join([self.decoder[token] for token in tokens]) + text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') + return text diff --git a/deforum-stable-diffusion/src/infer.py b/deforum-stable-diffusion/src/infer.py new file mode 100644 index 0000000000000000000000000000000000000000..82f831ba1991828bbd3b09b6c3f1c7703f2604b7 --- /dev/null +++ b/deforum-stable-diffusion/src/infer.py @@ -0,0 +1,161 @@ +import glob +import os + +import numpy as np +import torch +import torch.nn as nn +from PIL import Image +from torchvision import transforms +from tqdm import tqdm + +import model_io +import utils +from adabins import UnetAdaptiveBins + + +def _is_pil_image(img): + return isinstance(img, Image.Image) + + +def _is_numpy_image(img): + return isinstance(img, np.ndarray) and (img.ndim in {2, 3}) + + +class ToTensor(object): + def __init__(self): + self.normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + + def __call__(self, image, target_size=(640, 480)): + # image = image.resize(target_size) + image = self.to_tensor(image) + image = self.normalize(image) + return image + + def to_tensor(self, pic): + if not (_is_pil_image(pic) or _is_numpy_image(pic)): + raise TypeError( + 'pic should be PIL Image or ndarray. Got {}'.format(type(pic))) + + if isinstance(pic, np.ndarray): + img = torch.from_numpy(pic.transpose((2, 0, 1))) + return img + + # handle PIL Image + if pic.mode == 'I': + img = torch.from_numpy(np.array(pic, np.int32, copy=False)) + elif pic.mode == 'I;16': + img = torch.from_numpy(np.array(pic, np.int16, copy=False)) + else: + img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes())) + # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK + if pic.mode == 'YCbCr': + nchannel = 3 + elif pic.mode == 'I;16': + nchannel = 1 + else: + nchannel = len(pic.mode) + img = img.view(pic.size[1], pic.size[0], nchannel) + + img = img.transpose(0, 1).transpose(0, 2).contiguous() + if isinstance(img, torch.ByteTensor): + return img.float() + else: + return img + + +class InferenceHelper: + def __init__(self, models_path, dataset='nyu', device='cuda:0'): + self.toTensor = ToTensor() + self.device = device + if dataset == 'nyu': + self.min_depth = 1e-3 + self.max_depth = 10 + self.saving_factor = 1000 # used to save in 16 bit + model = UnetAdaptiveBins.build(n_bins=256, min_val=self.min_depth, max_val=self.max_depth) + pretrained_path = os.path.join(models_path,'AdaBins_nyu.pt') + elif dataset == 'kitti': + self.min_depth = 1e-3 + self.max_depth = 80 + self.saving_factor = 256 + model = UnetAdaptiveBins.build(n_bins=256, min_val=self.min_depth, max_val=self.max_depth) + pretrained_path = "./models/AdaBins_kitti.pt" + else: + raise ValueError("dataset can be either 'nyu' or 'kitti' but got {}".format(dataset)) + + model, _, _ = model_io.load_checkpoint(pretrained_path, model) + model.eval() + self.model = model.to(self.device) + + @torch.no_grad() + def predict_pil(self, pil_image, visualized=False): + # pil_image = pil_image.resize((640, 480)) + img = np.asarray(pil_image) / 255. + + img = self.toTensor(img).unsqueeze(0).float().to(self.device) + bin_centers, pred = self.predict(img) + + if visualized: + viz = utils.colorize(torch.from_numpy(pred).unsqueeze(0), vmin=None, vmax=None, cmap='magma') + # pred = np.asarray(pred*1000, dtype='uint16') + viz = Image.fromarray(viz) + return bin_centers, pred, viz + return bin_centers, pred + + @torch.no_grad() + def predict(self, image): + bins, pred = self.model(image) + pred = np.clip(pred.cpu().numpy(), self.min_depth, self.max_depth) + + # Flip + image = torch.Tensor(np.array(image.cpu().numpy())[..., ::-1].copy()).to(self.device) + pred_lr = self.model(image)[-1] + pred_lr = np.clip(pred_lr.cpu().numpy()[..., ::-1], self.min_depth, self.max_depth) + + # Take average of original and mirror + final = 0.5 * (pred + pred_lr) + final = nn.functional.interpolate(torch.Tensor(final), image.shape[-2:], + mode='bilinear', align_corners=True).cpu().numpy() + + final[final < self.min_depth] = self.min_depth + final[final > self.max_depth] = self.max_depth + final[np.isinf(final)] = self.max_depth + final[np.isnan(final)] = self.min_depth + + centers = 0.5 * (bins[:, 1:] + bins[:, :-1]) + centers = centers.cpu().squeeze().numpy() + centers = centers[centers > self.min_depth] + centers = centers[centers < self.max_depth] + + return centers, final + + @torch.no_grad() + def predict_dir(self, test_dir, out_dir): + os.makedirs(out_dir, exist_ok=True) + transform = ToTensor() + all_files = glob.glob(os.path.join(test_dir, "*")) + self.model.eval() + for f in tqdm(all_files): + image = np.asarray(Image.open(f), dtype='float32') / 255. + image = transform(image).unsqueeze(0).to(self.device) + + centers, final = self.predict(image) + # final = final.squeeze().cpu().numpy() + + final = (final * self.saving_factor).astype('uint16') + basename = os.path.basename(f).split('.')[0] + save_path = os.path.join(out_dir, basename + ".png") + + Image.fromarray(final.squeeze()).save(save_path) + + +if __name__ == '__main__': + import matplotlib.pyplot as plt + from time import time + + img = Image.open("test_imgs/classroom__rgb_00283.jpg") + start = time() + inferHelper = InferenceHelper() + centers, pred = inferHelper.predict_pil(img) + print(f"took :{time() - start}s") + plt.imshow(pred.squeeze(), cmap='magma_r') + plt.show() diff --git a/deforum-stable-diffusion/src/k_diffusion/__init__.py b/deforum-stable-diffusion/src/k_diffusion/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/deforum-stable-diffusion/src/k_diffusion/__pycache__/__init__.cpython-39.pyc b/deforum-stable-diffusion/src/k_diffusion/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..58eadd516e7dc30e58f31e5dc2a923227bf1147e Binary files /dev/null and b/deforum-stable-diffusion/src/k_diffusion/__pycache__/__init__.cpython-39.pyc differ diff --git a/deforum-stable-diffusion/src/k_diffusion/__pycache__/external.cpython-39.pyc b/deforum-stable-diffusion/src/k_diffusion/__pycache__/external.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7d3c2301975675502dd41677e3a0a5517b545f3e Binary files /dev/null and b/deforum-stable-diffusion/src/k_diffusion/__pycache__/external.cpython-39.pyc differ diff --git a/deforum-stable-diffusion/src/k_diffusion/__pycache__/sampling.cpython-39.pyc b/deforum-stable-diffusion/src/k_diffusion/__pycache__/sampling.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8ad6a6d1f5b099998962599df71a6aadefb30399 Binary files /dev/null and b/deforum-stable-diffusion/src/k_diffusion/__pycache__/sampling.cpython-39.pyc differ diff --git a/deforum-stable-diffusion/src/k_diffusion/augmentation.py b/deforum-stable-diffusion/src/k_diffusion/augmentation.py new file mode 100644 index 0000000000000000000000000000000000000000..7dd17c686300c8ecba7fac134aa54f01619c3d46 --- /dev/null +++ b/deforum-stable-diffusion/src/k_diffusion/augmentation.py @@ -0,0 +1,105 @@ +from functools import reduce +import math +import operator + +import numpy as np +from skimage import transform +import torch +from torch import nn + + +def translate2d(tx, ty): + mat = [[1, 0, tx], + [0, 1, ty], + [0, 0, 1]] + return torch.tensor(mat, dtype=torch.float32) + + +def scale2d(sx, sy): + mat = [[sx, 0, 0], + [ 0, sy, 0], + [ 0, 0, 1]] + return torch.tensor(mat, dtype=torch.float32) + + +def rotate2d(theta): + mat = [[torch.cos(theta), torch.sin(-theta), 0], + [torch.sin(theta), torch.cos(theta), 0], + [ 0, 0, 1]] + return torch.tensor(mat, dtype=torch.float32) + + +class KarrasAugmentationPipeline: + def __init__(self, a_prob=0.12, a_scale=2**0.2, a_aniso=2**0.2, a_trans=1/8): + self.a_prob = a_prob + self.a_scale = a_scale + self.a_aniso = a_aniso + self.a_trans = a_trans + + def __call__(self, image): + h, w = image.size + mats = [translate2d(h / 2 - 0.5, w / 2 - 0.5)] + + # x-flip + a0 = torch.randint(2, []).float() + mats.append(scale2d(1 - 2 * a0, 1)) + # y-flip + do = (torch.rand([]) < self.a_prob).float() + a1 = torch.randint(2, []).float() * do + mats.append(scale2d(1, 1 - 2 * a1)) + # scaling + do = (torch.rand([]) < self.a_prob).float() + a2 = torch.randn([]) * do + mats.append(scale2d(self.a_scale ** a2, self.a_scale ** a2)) + # rotation + do = (torch.rand([]) < self.a_prob).float() + a3 = (torch.rand([]) * 2 * math.pi - math.pi) * do + mats.append(rotate2d(-a3)) + # anisotropy + do = (torch.rand([]) < self.a_prob).float() + a4 = (torch.rand([]) * 2 * math.pi - math.pi) * do + a5 = torch.randn([]) * do + mats.append(rotate2d(a4)) + mats.append(scale2d(self.a_aniso ** a5, self.a_aniso ** -a5)) + mats.append(rotate2d(-a4)) + # translation + do = (torch.rand([]) < self.a_prob).float() + a6 = torch.randn([]) * do + a7 = torch.randn([]) * do + mats.append(translate2d(self.a_trans * w * a6, self.a_trans * h * a7)) + + # form the transformation matrix and conditioning vector + mats.append(translate2d(-h / 2 + 0.5, -w / 2 + 0.5)) + mat = reduce(operator.matmul, mats) + cond = torch.stack([a0, a1, a2, a3.cos() - 1, a3.sin(), a5 * a4.cos(), a5 * a4.sin(), a6, a7]) + + # apply the transformation + image_orig = np.array(image, dtype=np.float32) / 255 + if image_orig.ndim == 2: + image_orig = image_orig[..., None] + tf = transform.AffineTransform(mat.numpy()) + image = transform.warp(image_orig, tf.inverse, order=3, mode='reflect', cval=0.5, clip=False, preserve_range=True) + image_orig = torch.as_tensor(image_orig).movedim(2, 0) * 2 - 1 + image = torch.as_tensor(image).movedim(2, 0) * 2 - 1 + return image, image_orig, cond + + +class KarrasAugmentWrapper(nn.Module): + def __init__(self, model): + super().__init__() + self.inner_model = model + + def forward(self, input, sigma, aug_cond=None, mapping_cond=None, **kwargs): + if aug_cond is None: + aug_cond = input.new_zeros([input.shape[0], 9]) + if mapping_cond is None: + mapping_cond = aug_cond + else: + mapping_cond = torch.cat([aug_cond, mapping_cond], dim=1) + return self.inner_model(input, sigma, mapping_cond=mapping_cond, **kwargs) + + def set_skip_stages(self, skip_stages): + return self.inner_model.set_skip_stages(skip_stages) + + def set_patch_size(self, patch_size): + return self.inner_model.set_patch_size(patch_size) diff --git a/deforum-stable-diffusion/src/k_diffusion/config.py b/deforum-stable-diffusion/src/k_diffusion/config.py new file mode 100644 index 0000000000000000000000000000000000000000..4b504d6d74b2fbdf92be6aa6f84955832f8c701a --- /dev/null +++ b/deforum-stable-diffusion/src/k_diffusion/config.py @@ -0,0 +1,110 @@ +from functools import partial +import json +import math +import warnings + +from jsonmerge import merge + +from . import augmentation, layers, models, utils + + +def load_config(file): + defaults = { + 'model': { + 'sigma_data': 1., + 'patch_size': 1, + 'dropout_rate': 0., + 'augment_wrapper': True, + 'augment_prob': 0., + 'mapping_cond_dim': 0, + 'unet_cond_dim': 0, + 'cross_cond_dim': 0, + 'cross_attn_depths': None, + 'skip_stages': 0, + 'has_variance': False, + }, + 'dataset': { + 'type': 'imagefolder', + }, + 'optimizer': { + 'type': 'adamw', + 'lr': 1e-4, + 'betas': [0.95, 0.999], + 'eps': 1e-6, + 'weight_decay': 1e-3, + }, + 'lr_sched': { + 'type': 'inverse', + 'inv_gamma': 20000., + 'power': 1., + 'warmup': 0.99, + }, + 'ema_sched': { + 'type': 'inverse', + 'power': 0.6667, + 'max_value': 0.9999 + }, + } + config = json.load(file) + return merge(defaults, config) + + +def make_model(config): + config = config['model'] + assert config['type'] == 'image_v1' + model = models.ImageDenoiserModelV1( + config['input_channels'], + config['mapping_out'], + config['depths'], + config['channels'], + config['self_attn_depths'], + config['cross_attn_depths'], + patch_size=config['patch_size'], + dropout_rate=config['dropout_rate'], + mapping_cond_dim=config['mapping_cond_dim'] + (9 if config['augment_wrapper'] else 0), + unet_cond_dim=config['unet_cond_dim'], + cross_cond_dim=config['cross_cond_dim'], + skip_stages=config['skip_stages'], + has_variance=config['has_variance'], + ) + if config['augment_wrapper']: + model = augmentation.KarrasAugmentWrapper(model) + return model + + +def make_denoiser_wrapper(config): + config = config['model'] + sigma_data = config.get('sigma_data', 1.) + has_variance = config.get('has_variance', False) + if not has_variance: + return partial(layers.Denoiser, sigma_data=sigma_data) + return partial(layers.DenoiserWithVariance, sigma_data=sigma_data) + + +def make_sample_density(config): + sd_config = config['sigma_sample_density'] + sigma_data = config['sigma_data'] + if sd_config['type'] == 'lognormal': + loc = sd_config['mean'] if 'mean' in sd_config else sd_config['loc'] + scale = sd_config['std'] if 'std' in sd_config else sd_config['scale'] + return partial(utils.rand_log_normal, loc=loc, scale=scale) + if sd_config['type'] == 'loglogistic': + loc = sd_config['loc'] if 'loc' in sd_config else math.log(sigma_data) + scale = sd_config['scale'] if 'scale' in sd_config else 0.5 + min_value = sd_config['min_value'] if 'min_value' in sd_config else 0. + max_value = sd_config['max_value'] if 'max_value' in sd_config else float('inf') + return partial(utils.rand_log_logistic, loc=loc, scale=scale, min_value=min_value, max_value=max_value) + if sd_config['type'] == 'loguniform': + min_value = sd_config['min_value'] if 'min_value' in sd_config else config['sigma_min'] + max_value = sd_config['max_value'] if 'max_value' in sd_config else config['sigma_max'] + return partial(utils.rand_log_uniform, min_value=min_value, max_value=max_value) + if sd_config['type'] == 'v-diffusion': + min_value = sd_config['min_value'] if 'min_value' in sd_config else 0. + max_value = sd_config['max_value'] if 'max_value' in sd_config else float('inf') + return partial(utils.rand_v_diffusion, sigma_data=sigma_data, min_value=min_value, max_value=max_value) + if sd_config['type'] == 'split-lognormal': + loc = sd_config['mean'] if 'mean' in sd_config else sd_config['loc'] + scale_1 = sd_config['std_1'] if 'std_1' in sd_config else sd_config['scale_1'] + scale_2 = sd_config['std_2'] if 'std_2' in sd_config else sd_config['scale_2'] + return partial(utils.rand_split_log_normal, loc=loc, scale_1=scale_1, scale_2=scale_2) + raise ValueError('Unknown sample density type') diff --git a/deforum-stable-diffusion/src/k_diffusion/evaluation.py b/deforum-stable-diffusion/src/k_diffusion/evaluation.py new file mode 100644 index 0000000000000000000000000000000000000000..2c34bbf1656854d9cf233b7620b684e44b30de82 --- /dev/null +++ b/deforum-stable-diffusion/src/k_diffusion/evaluation.py @@ -0,0 +1,134 @@ +import math +import os +from pathlib import Path + +from cleanfid.inception_torchscript import InceptionV3W +import clip +from resize_right import resize +import torch +from torch import nn +from torch.nn import functional as F +from torchvision import transforms +from tqdm.auto import trange + +from . import utils + + +class InceptionV3FeatureExtractor(nn.Module): + def __init__(self, device='cpu'): + super().__init__() + path = Path(os.environ.get('XDG_CACHE_HOME', Path.home() / '.cache')) / 'k-diffusion' + url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt' + digest = 'f58cb9b6ec323ed63459aa4fb441fe750cfe39fafad6da5cb504a16f19e958f4' + utils.download_file(path / 'inception-2015-12-05.pt', url, digest) + self.model = InceptionV3W(str(path), resize_inside=False).to(device) + self.size = (299, 299) + + def forward(self, x): + if x.shape[2:4] != self.size: + x = resize(x, out_shape=self.size, pad_mode='reflect') + if x.shape[1] == 1: + x = torch.cat([x] * 3, dim=1) + x = (x * 127.5 + 127.5).clamp(0, 255) + return self.model(x) + + +class CLIPFeatureExtractor(nn.Module): + def __init__(self, name='ViT-L/14@336px', device='cpu'): + super().__init__() + self.model = clip.load(name, device=device)[0].eval().requires_grad_(False) + self.normalize = transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073), + std=(0.26862954, 0.26130258, 0.27577711)) + self.size = (self.model.visual.input_resolution, self.model.visual.input_resolution) + + def forward(self, x): + if x.shape[2:4] != self.size: + x = resize(x.add(1).div(2), out_shape=self.size, pad_mode='reflect').clamp(0, 1) + x = self.normalize(x) + x = self.model.encode_image(x).float() + x = F.normalize(x) * x.shape[1] ** 0.5 + return x + + +def compute_features(accelerator, sample_fn, extractor_fn, n, batch_size): + n_per_proc = math.ceil(n / accelerator.num_processes) + feats_all = [] + try: + for i in trange(0, n_per_proc, batch_size, disable=not accelerator.is_main_process): + cur_batch_size = min(n - i, batch_size) + samples = sample_fn(cur_batch_size)[:cur_batch_size] + feats_all.append(accelerator.gather(extractor_fn(samples))) + except StopIteration: + pass + return torch.cat(feats_all)[:n] + + +def polynomial_kernel(x, y): + d = x.shape[-1] + dot = x @ y.transpose(-2, -1) + return (dot / d + 1) ** 3 + + +def squared_mmd(x, y, kernel=polynomial_kernel): + m = x.shape[-2] + n = y.shape[-2] + kxx = kernel(x, x) + kyy = kernel(y, y) + kxy = kernel(x, y) + kxx_sum = kxx.sum([-1, -2]) - kxx.diagonal(dim1=-1, dim2=-2).sum(-1) + kyy_sum = kyy.sum([-1, -2]) - kyy.diagonal(dim1=-1, dim2=-2).sum(-1) + kxy_sum = kxy.sum([-1, -2]) + term_1 = kxx_sum / m / (m - 1) + term_2 = kyy_sum / n / (n - 1) + term_3 = kxy_sum * 2 / m / n + return term_1 + term_2 - term_3 + + +@utils.tf32_mode(matmul=False) +def kid(x, y, max_size=5000): + x_size, y_size = x.shape[0], y.shape[0] + n_partitions = math.ceil(max(x_size / max_size, y_size / max_size)) + total_mmd = x.new_zeros([]) + for i in range(n_partitions): + cur_x = x[round(i * x_size / n_partitions):round((i + 1) * x_size / n_partitions)] + cur_y = y[round(i * y_size / n_partitions):round((i + 1) * y_size / n_partitions)] + total_mmd = total_mmd + squared_mmd(cur_x, cur_y) + return total_mmd / n_partitions + + +class _MatrixSquareRootEig(torch.autograd.Function): + @staticmethod + def forward(ctx, a): + vals, vecs = torch.linalg.eigh(a) + ctx.save_for_backward(vals, vecs) + return vecs @ vals.abs().sqrt().diag_embed() @ vecs.transpose(-2, -1) + + @staticmethod + def backward(ctx, grad_output): + vals, vecs = ctx.saved_tensors + d = vals.abs().sqrt().unsqueeze(-1).repeat_interleave(vals.shape[-1], -1) + vecs_t = vecs.transpose(-2, -1) + return vecs @ (vecs_t @ grad_output @ vecs / (d + d.transpose(-2, -1))) @ vecs_t + + +def sqrtm_eig(a): + if a.ndim < 2: + raise RuntimeError('tensor of matrices must have at least 2 dimensions') + if a.shape[-2] != a.shape[-1]: + raise RuntimeError('tensor must be batches of square matrices') + return _MatrixSquareRootEig.apply(a) + + +@utils.tf32_mode(matmul=False) +def fid(x, y, eps=1e-8): + x_mean = x.mean(dim=0) + y_mean = y.mean(dim=0) + mean_term = (x_mean - y_mean).pow(2).sum() + x_cov = torch.cov(x.T) + y_cov = torch.cov(y.T) + eps_eye = torch.eye(x_cov.shape[0], device=x_cov.device, dtype=x_cov.dtype) * eps + x_cov = x_cov + eps_eye + y_cov = y_cov + eps_eye + x_cov_sqrt = sqrtm_eig(x_cov) + cov_term = torch.trace(x_cov + y_cov - 2 * sqrtm_eig(x_cov_sqrt @ y_cov @ x_cov_sqrt)) + return mean_term + cov_term diff --git a/deforum-stable-diffusion/src/k_diffusion/external.py b/deforum-stable-diffusion/src/k_diffusion/external.py new file mode 100644 index 0000000000000000000000000000000000000000..2f1d2588481dc3ab97b3285af87629a17950ae2c --- /dev/null +++ b/deforum-stable-diffusion/src/k_diffusion/external.py @@ -0,0 +1,138 @@ +import math + +import torch +from torch import nn + +from . import sampling, utils + + +class VDenoiser(nn.Module): + """A v-diffusion-pytorch model wrapper for k-diffusion.""" + + def __init__(self, inner_model): + super().__init__() + self.inner_model = inner_model + self.sigma_data = 1. + + def get_scalings(self, sigma): + c_skip = self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) + c_out = -sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5 + c_in = 1 / (sigma ** 2 + self.sigma_data ** 2) ** 0.5 + return c_skip, c_out, c_in + + def sigma_to_t(self, sigma): + return sigma.atan() / math.pi * 2 + + def t_to_sigma(self, t): + return (t * math.pi / 2).tan() + + def loss(self, input, noise, sigma, **kwargs): + c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)] + noised_input = input + noise * utils.append_dims(sigma, input.ndim) + model_output = self.inner_model(noised_input * c_in, self.sigma_to_t(sigma), **kwargs) + target = (input - c_skip * noised_input) / c_out + return (model_output - target).pow(2).flatten(1).mean(1) + + def forward(self, input, sigma, **kwargs): + c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)] + return self.inner_model(input * c_in, self.sigma_to_t(sigma), **kwargs) * c_out + input * c_skip + + +class DiscreteSchedule(nn.Module): + """A mapping between continuous noise levels (sigmas) and a list of discrete noise + levels.""" + + def __init__(self, sigmas, quantize): + super().__init__() + self.register_buffer('sigmas', sigmas) + self.register_buffer('log_sigmas', sigmas.log()) + self.quantize = quantize + + @property + def sigma_min(self): + return self.sigmas[0] + + @property + def sigma_max(self): + return self.sigmas[-1] + + def get_sigmas(self, n=None): + if n is None: + return sampling.append_zero(self.sigmas.flip(0)) + t_max = len(self.sigmas) - 1 + t = torch.linspace(t_max, 0, n, device=self.sigmas.device) + return sampling.append_zero(self.t_to_sigma(t)) + + def sigma_to_t(self, sigma, quantize=None): + quantize = self.quantize if quantize is None else quantize + log_sigma = sigma.log() + dists = log_sigma - self.log_sigmas[:, None] + if quantize: + return dists.abs().argmin(dim=0).view(sigma.shape) + low_idx = dists.ge(0).cumsum(dim=0).argmax(dim=0).clamp(max=self.log_sigmas.shape[0] - 2) + high_idx = low_idx + 1 + low, high = self.log_sigmas[low_idx], self.log_sigmas[high_idx] + w = (low - log_sigma) / (low - high) + w = w.clamp(0, 1) + t = (1 - w) * low_idx + w * high_idx + return t.view(sigma.shape) + + def t_to_sigma(self, t): + t = t.float() + low_idx, high_idx, w = t.floor().long(), t.ceil().long(), t.frac() + log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx] + return log_sigma.exp() + + +class DiscreteEpsDDPMDenoiser(DiscreteSchedule): + """A wrapper for discrete schedule DDPM models that output eps (the predicted + noise).""" + + def __init__(self, model, alphas_cumprod, quantize): + super().__init__(((1 - alphas_cumprod) / alphas_cumprod) ** 0.5, quantize) + self.inner_model = model + self.sigma_data = 1. + + def get_scalings(self, sigma): + c_out = -sigma + c_in = 1 / (sigma ** 2 + self.sigma_data ** 2) ** 0.5 + return c_out, c_in + + def get_eps(self, *args, **kwargs): + return self.inner_model(*args, **kwargs) + + def loss(self, input, noise, sigma, **kwargs): + c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)] + noised_input = input + noise * utils.append_dims(sigma, input.ndim) + eps = self.get_eps(noised_input * c_in, self.sigma_to_t(sigma), **kwargs) + return (eps - noise).pow(2).flatten(1).mean(1) + + def forward(self, input, sigma, **kwargs): + c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)] + eps = self.get_eps(input * c_in, self.sigma_to_t(sigma), **kwargs) + return input + eps * c_out + + +class OpenAIDenoiser(DiscreteEpsDDPMDenoiser): + """A wrapper for OpenAI diffusion models.""" + + def __init__(self, model, diffusion, quantize=False, has_learned_sigmas=True, device='cpu'): + alphas_cumprod = torch.tensor(diffusion.alphas_cumprod, device=device, dtype=torch.float32) + super().__init__(model, alphas_cumprod, quantize=quantize) + self.has_learned_sigmas = has_learned_sigmas + + def get_eps(self, *args, **kwargs): + model_output = self.inner_model(*args, **kwargs) + if self.has_learned_sigmas: + return model_output.chunk(2, dim=1)[0] + return model_output + + +class CompVisDenoiser(DiscreteEpsDDPMDenoiser): + """A wrapper for CompVis diffusion models.""" + + def __init__(self, model, quantize=False, device='cpu'): + super().__init__(model, model.alphas_cumprod, quantize=quantize) + + def get_eps(self, *args, **kwargs): + return self.inner_model.apply_model(*args, **kwargs) diff --git a/deforum-stable-diffusion/src/k_diffusion/gns.py b/deforum-stable-diffusion/src/k_diffusion/gns.py new file mode 100644 index 0000000000000000000000000000000000000000..dcb7b8d8a9aeae38a7f961c63f66cca4ef90a9e7 --- /dev/null +++ b/deforum-stable-diffusion/src/k_diffusion/gns.py @@ -0,0 +1,99 @@ +import torch +from torch import nn + + +class DDPGradientStatsHook: + def __init__(self, ddp_module): + try: + ddp_module.register_comm_hook(self, self._hook_fn) + except AttributeError: + raise ValueError('DDPGradientStatsHook does not support non-DDP wrapped modules') + self._clear_state() + + def _clear_state(self): + self.bucket_sq_norms_small_batch = [] + self.bucket_sq_norms_large_batch = [] + + @staticmethod + def _hook_fn(self, bucket): + buf = bucket.buffer() + self.bucket_sq_norms_small_batch.append(buf.pow(2).sum()) + fut = torch.distributed.all_reduce(buf, op=torch.distributed.ReduceOp.AVG, async_op=True).get_future() + def callback(fut): + buf = fut.value()[0] + self.bucket_sq_norms_large_batch.append(buf.pow(2).sum()) + return buf + return fut.then(callback) + + def get_stats(self): + sq_norm_small_batch = sum(self.bucket_sq_norms_small_batch) + sq_norm_large_batch = sum(self.bucket_sq_norms_large_batch) + self._clear_state() + stats = torch.stack([sq_norm_small_batch, sq_norm_large_batch]) + torch.distributed.all_reduce(stats, op=torch.distributed.ReduceOp.AVG) + return stats[0].item(), stats[1].item() + + +class GradientNoiseScale: + """Calculates the gradient noise scale (1 / SNR), or critical batch size, + from _An Empirical Model of Large-Batch Training_, + https://arxiv.org/abs/1812.06162). + + Args: + beta (float): The decay factor for the exponential moving averages used to + calculate the gradient noise scale. + Default: 0.9998 + eps (float): Added for numerical stability. + Default: 1e-8 + """ + + def __init__(self, beta=0.9998, eps=1e-8): + self.beta = beta + self.eps = eps + self.ema_sq_norm = 0. + self.ema_var = 0. + self.beta_cumprod = 1. + self.gradient_noise_scale = float('nan') + + def state_dict(self): + """Returns the state of the object as a :class:`dict`.""" + return dict(self.__dict__.items()) + + def load_state_dict(self, state_dict): + """Loads the object's state. + Args: + state_dict (dict): object state. Should be an object returned + from a call to :meth:`state_dict`. + """ + self.__dict__.update(state_dict) + + def update(self, sq_norm_small_batch, sq_norm_large_batch, n_small_batch, n_large_batch): + """Updates the state with a new batch's gradient statistics, and returns the + current gradient noise scale. + + Args: + sq_norm_small_batch (float): The mean of the squared 2-norms of microbatch or + per sample gradients. + sq_norm_large_batch (float): The squared 2-norm of the mean of the microbatch or + per sample gradients. + n_small_batch (int): The batch size of the individual microbatch or per sample + gradients (1 if per sample). + n_large_batch (int): The total batch size of the mean of the microbatch or + per sample gradients. + """ + est_sq_norm = (n_large_batch * sq_norm_large_batch - n_small_batch * sq_norm_small_batch) / (n_large_batch - n_small_batch) + est_var = (sq_norm_small_batch - sq_norm_large_batch) / (1 / n_small_batch - 1 / n_large_batch) + self.ema_sq_norm = self.beta * self.ema_sq_norm + (1 - self.beta) * est_sq_norm + self.ema_var = self.beta * self.ema_var + (1 - self.beta) * est_var + self.beta_cumprod *= self.beta + self.gradient_noise_scale = max(self.ema_var, self.eps) / max(self.ema_sq_norm, self.eps) + return self.gradient_noise_scale + + def get_gns(self): + """Returns the current gradient noise scale.""" + return self.gradient_noise_scale + + def get_stats(self): + """Returns the current (debiased) estimates of the squared mean gradient + and gradient variance.""" + return self.ema_sq_norm / (1 - self.beta_cumprod), self.ema_var / (1 - self.beta_cumprod) diff --git a/deforum-stable-diffusion/src/k_diffusion/layers.py b/deforum-stable-diffusion/src/k_diffusion/layers.py new file mode 100644 index 0000000000000000000000000000000000000000..cdeba0ad68f584261bd88de608e843a350489544 --- /dev/null +++ b/deforum-stable-diffusion/src/k_diffusion/layers.py @@ -0,0 +1,246 @@ +import math + +from einops import rearrange, repeat +import torch +from torch import nn +from torch.nn import functional as F + +from . import utils + +# Karras et al. preconditioned denoiser + +class Denoiser(nn.Module): + """A Karras et al. preconditioner for denoising diffusion models.""" + + def __init__(self, inner_model, sigma_data=1.): + super().__init__() + self.inner_model = inner_model + self.sigma_data = sigma_data + + def get_scalings(self, sigma): + c_skip = self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) + c_out = sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5 + c_in = 1 / (sigma ** 2 + self.sigma_data ** 2) ** 0.5 + return c_skip, c_out, c_in + + def loss(self, input, noise, sigma, **kwargs): + c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)] + noised_input = input + noise * utils.append_dims(sigma, input.ndim) + model_output = self.inner_model(noised_input * c_in, sigma, **kwargs) + target = (input - c_skip * noised_input) / c_out + return (model_output - target).pow(2).flatten(1).mean(1) + + def forward(self, input, sigma, **kwargs): + c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)] + return self.inner_model(input * c_in, sigma, **kwargs) * c_out + input * c_skip + + +class DenoiserWithVariance(Denoiser): + def loss(self, input, noise, sigma, **kwargs): + c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)] + noised_input = input + noise * utils.append_dims(sigma, input.ndim) + model_output, logvar = self.inner_model(noised_input * c_in, sigma, return_variance=True, **kwargs) + logvar = utils.append_dims(logvar, model_output.ndim) + target = (input - c_skip * noised_input) / c_out + losses = ((model_output - target) ** 2 / logvar.exp() + logvar) / 2 + return losses.flatten(1).mean(1) + + +# Residual blocks + +class ResidualBlock(nn.Module): + def __init__(self, *main, skip=None): + super().__init__() + self.main = nn.Sequential(*main) + self.skip = skip if skip else nn.Identity() + + def forward(self, input): + return self.main(input) + self.skip(input) + + +# Noise level (and other) conditioning + +class ConditionedModule(nn.Module): + pass + + +class UnconditionedModule(ConditionedModule): + def __init__(self, module): + super().__init__() + self.module = module + + def forward(self, input, cond=None): + return self.module(input) + + +class ConditionedSequential(nn.Sequential, ConditionedModule): + def forward(self, input, cond): + for module in self: + if isinstance(module, ConditionedModule): + input = module(input, cond) + else: + input = module(input) + return input + + +class ConditionedResidualBlock(ConditionedModule): + def __init__(self, *main, skip=None): + super().__init__() + self.main = ConditionedSequential(*main) + self.skip = skip if skip else nn.Identity() + + def forward(self, input, cond): + skip = self.skip(input, cond) if isinstance(self.skip, ConditionedModule) else self.skip(input) + return self.main(input, cond) + skip + + +class AdaGN(ConditionedModule): + def __init__(self, feats_in, c_out, num_groups, eps=1e-5, cond_key='cond'): + super().__init__() + self.num_groups = num_groups + self.eps = eps + self.cond_key = cond_key + self.mapper = nn.Linear(feats_in, c_out * 2) + + def forward(self, input, cond): + weight, bias = self.mapper(cond[self.cond_key]).chunk(2, dim=-1) + input = F.group_norm(input, self.num_groups, eps=self.eps) + return torch.addcmul(utils.append_dims(bias, input.ndim), input, utils.append_dims(weight, input.ndim) + 1) + + +# Attention + +class SelfAttention2d(ConditionedModule): + def __init__(self, c_in, n_head, norm, dropout_rate=0.): + super().__init__() + assert c_in % n_head == 0 + self.norm_in = norm(c_in) + self.n_head = n_head + self.qkv_proj = nn.Conv2d(c_in, c_in * 3, 1) + self.out_proj = nn.Conv2d(c_in, c_in, 1) + self.dropout = nn.Dropout(dropout_rate) + + def forward(self, input, cond): + n, c, h, w = input.shape + qkv = self.qkv_proj(self.norm_in(input, cond)) + qkv = qkv.view([n, self.n_head * 3, c // self.n_head, h * w]).transpose(2, 3) + q, k, v = qkv.chunk(3, dim=1) + scale = k.shape[3] ** -0.25 + att = ((q * scale) @ (k.transpose(2, 3) * scale)).softmax(3) + att = self.dropout(att) + y = (att @ v).transpose(2, 3).contiguous().view([n, c, h, w]) + return input + self.out_proj(y) + + +class CrossAttention2d(ConditionedModule): + def __init__(self, c_dec, c_enc, n_head, norm_dec, dropout_rate=0., + cond_key='cross', cond_key_padding='cross_padding'): + super().__init__() + assert c_dec % n_head == 0 + self.cond_key = cond_key + self.cond_key_padding = cond_key_padding + self.norm_enc = nn.LayerNorm(c_enc) + self.norm_dec = norm_dec(c_dec) + self.n_head = n_head + self.q_proj = nn.Conv2d(c_dec, c_dec, 1) + self.kv_proj = nn.Linear(c_enc, c_dec * 2) + self.out_proj = nn.Conv2d(c_dec, c_dec, 1) + self.dropout = nn.Dropout(dropout_rate) + + def forward(self, input, cond): + n, c, h, w = input.shape + q = self.q_proj(self.norm_dec(input, cond)) + q = q.view([n, self.n_head, c // self.n_head, h * w]).transpose(2, 3) + kv = self.kv_proj(self.norm_enc(cond[self.cond_key])) + kv = kv.view([n, -1, self.n_head * 2, c // self.n_head]).transpose(1, 2) + k, v = kv.chunk(2, dim=1) + scale = k.shape[3] ** -0.25 + att = ((q * scale) @ (k.transpose(2, 3) * scale)) + att = att - (cond[self.cond_key_padding][:, None, None, :]) * 10000 + att = att.softmax(3) + att = self.dropout(att) + y = (att @ v).transpose(2, 3) + y = y.contiguous().view([n, c, h, w]) + return input + self.out_proj(y) + + +# Downsampling/upsampling + +_kernels = { + 'linear': + [1 / 8, 3 / 8, 3 / 8, 1 / 8], + 'cubic': + [-0.01171875, -0.03515625, 0.11328125, 0.43359375, + 0.43359375, 0.11328125, -0.03515625, -0.01171875], + 'lanczos3': + [0.003689131001010537, 0.015056144446134567, -0.03399861603975296, + -0.066637322306633, 0.13550527393817902, 0.44638532400131226, + 0.44638532400131226, 0.13550527393817902, -0.066637322306633, + -0.03399861603975296, 0.015056144446134567, 0.003689131001010537] +} +_kernels['bilinear'] = _kernels['linear'] +_kernels['bicubic'] = _kernels['cubic'] + + +class Downsample2d(nn.Module): + def __init__(self, kernel='linear', pad_mode='reflect'): + super().__init__() + self.pad_mode = pad_mode + kernel_1d = torch.tensor([_kernels[kernel]]) + self.pad = kernel_1d.shape[1] // 2 - 1 + self.register_buffer('kernel', kernel_1d.T @ kernel_1d) + + def forward(self, x): + x = F.pad(x, (self.pad,) * 4, self.pad_mode) + weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0], self.kernel.shape[1]]) + indices = torch.arange(x.shape[1], device=x.device) + weight[indices, indices] = self.kernel.to(weight) + return F.conv2d(x, weight, stride=2) + + +class Upsample2d(nn.Module): + def __init__(self, kernel='linear', pad_mode='reflect'): + super().__init__() + self.pad_mode = pad_mode + kernel_1d = torch.tensor([_kernels[kernel]]) * 2 + self.pad = kernel_1d.shape[1] // 2 - 1 + self.register_buffer('kernel', kernel_1d.T @ kernel_1d) + + def forward(self, x): + x = F.pad(x, ((self.pad + 1) // 2,) * 4, self.pad_mode) + weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0], self.kernel.shape[1]]) + indices = torch.arange(x.shape[1], device=x.device) + weight[indices, indices] = self.kernel.to(weight) + return F.conv_transpose2d(x, weight, stride=2, padding=self.pad * 2 + 1) + + +# Embeddings + +class FourierFeatures(nn.Module): + def __init__(self, in_features, out_features, std=1.): + super().__init__() + assert out_features % 2 == 0 + self.register_buffer('weight', torch.randn([out_features // 2, in_features]) * std) + + def forward(self, input): + f = 2 * math.pi * input @ self.weight.T + return torch.cat([f.cos(), f.sin()], dim=-1) + + +# U-Nets + +class UNet(ConditionedModule): + def __init__(self, d_blocks, u_blocks, skip_stages=0): + super().__init__() + self.d_blocks = nn.ModuleList(d_blocks) + self.u_blocks = nn.ModuleList(u_blocks) + self.skip_stages = skip_stages + + def forward(self, input, cond): + skips = [] + for block in self.d_blocks[self.skip_stages:]: + input = block(input, cond) + skips.append(input) + for i, (block, skip) in enumerate(zip(self.u_blocks, reversed(skips))): + input = block(input, cond, skip if i > 0 else None) + return input diff --git a/deforum-stable-diffusion/src/k_diffusion/models/__init__.py b/deforum-stable-diffusion/src/k_diffusion/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..82608ff1de6137b31eeaf8de6814df6a7e35606a --- /dev/null +++ b/deforum-stable-diffusion/src/k_diffusion/models/__init__.py @@ -0,0 +1 @@ +from .image_v1 import ImageDenoiserModelV1 diff --git a/deforum-stable-diffusion/src/k_diffusion/models/image_v1.py b/deforum-stable-diffusion/src/k_diffusion/models/image_v1.py new file mode 100644 index 0000000000000000000000000000000000000000..9ffd5f2c4d6c9d086107d5fac67452419696c723 --- /dev/null +++ b/deforum-stable-diffusion/src/k_diffusion/models/image_v1.py @@ -0,0 +1,156 @@ +import math + +import torch +from torch import nn +from torch.nn import functional as F + +from .. import layers, utils + + +def orthogonal_(module): + nn.init.orthogonal_(module.weight) + return module + + +class ResConvBlock(layers.ConditionedResidualBlock): + def __init__(self, feats_in, c_in, c_mid, c_out, group_size=32, dropout_rate=0.): + skip = None if c_in == c_out else orthogonal_(nn.Conv2d(c_in, c_out, 1, bias=False)) + super().__init__( + layers.AdaGN(feats_in, c_in, max(1, c_in // group_size)), + nn.GELU(), + nn.Conv2d(c_in, c_mid, 3, padding=1), + nn.Dropout2d(dropout_rate, inplace=True), + layers.AdaGN(feats_in, c_mid, max(1, c_mid // group_size)), + nn.GELU(), + nn.Conv2d(c_mid, c_out, 3, padding=1), + nn.Dropout2d(dropout_rate, inplace=True), + skip=skip) + + +class DBlock(layers.ConditionedSequential): + def __init__(self, n_layers, feats_in, c_in, c_mid, c_out, group_size=32, head_size=64, dropout_rate=0., downsample=False, self_attn=False, cross_attn=False, c_enc=0): + modules = [nn.Identity()] + for i in range(n_layers): + my_c_in = c_in if i == 0 else c_mid + my_c_out = c_mid if i < n_layers - 1 else c_out + modules.append(ResConvBlock(feats_in, my_c_in, c_mid, my_c_out, group_size, dropout_rate)) + if self_attn: + norm = lambda c_in: layers.AdaGN(feats_in, c_in, max(1, my_c_out // group_size)) + modules.append(layers.SelfAttention2d(my_c_out, max(1, my_c_out // head_size), norm, dropout_rate)) + if cross_attn: + norm = lambda c_in: layers.AdaGN(feats_in, c_in, max(1, my_c_out // group_size)) + modules.append(layers.CrossAttention2d(my_c_out, c_enc, max(1, my_c_out // head_size), norm, dropout_rate)) + super().__init__(*modules) + self.set_downsample(downsample) + + def set_downsample(self, downsample): + self[0] = layers.Downsample2d() if downsample else nn.Identity() + return self + + +class UBlock(layers.ConditionedSequential): + def __init__(self, n_layers, feats_in, c_in, c_mid, c_out, group_size=32, head_size=64, dropout_rate=0., upsample=False, self_attn=False, cross_attn=False, c_enc=0): + modules = [] + for i in range(n_layers): + my_c_in = c_in if i == 0 else c_mid + my_c_out = c_mid if i < n_layers - 1 else c_out + modules.append(ResConvBlock(feats_in, my_c_in, c_mid, my_c_out, group_size, dropout_rate)) + if self_attn: + norm = lambda c_in: layers.AdaGN(feats_in, c_in, max(1, my_c_out // group_size)) + modules.append(layers.SelfAttention2d(my_c_out, max(1, my_c_out // head_size), norm, dropout_rate)) + if cross_attn: + norm = lambda c_in: layers.AdaGN(feats_in, c_in, max(1, my_c_out // group_size)) + modules.append(layers.CrossAttention2d(my_c_out, c_enc, max(1, my_c_out // head_size), norm, dropout_rate)) + modules.append(nn.Identity()) + super().__init__(*modules) + self.set_upsample(upsample) + + def forward(self, input, cond, skip=None): + if skip is not None: + input = torch.cat([input, skip], dim=1) + return super().forward(input, cond) + + def set_upsample(self, upsample): + self[-1] = layers.Upsample2d() if upsample else nn.Identity() + return self + + +class MappingNet(nn.Sequential): + def __init__(self, feats_in, feats_out, n_layers=2): + layers = [] + for i in range(n_layers): + layers.append(orthogonal_(nn.Linear(feats_in if i == 0 else feats_out, feats_out))) + layers.append(nn.GELU()) + super().__init__(*layers) + + +class ImageDenoiserModelV1(nn.Module): + def __init__(self, c_in, feats_in, depths, channels, self_attn_depths, cross_attn_depths=None, mapping_cond_dim=0, unet_cond_dim=0, cross_cond_dim=0, dropout_rate=0., patch_size=1, skip_stages=0, has_variance=False): + super().__init__() + self.c_in = c_in + self.channels = channels + self.unet_cond_dim = unet_cond_dim + self.patch_size = patch_size + self.has_variance = has_variance + self.timestep_embed = layers.FourierFeatures(1, feats_in) + if mapping_cond_dim > 0: + self.mapping_cond = nn.Linear(mapping_cond_dim, feats_in, bias=False) + self.mapping = MappingNet(feats_in, feats_in) + self.proj_in = nn.Conv2d((c_in + unet_cond_dim) * self.patch_size ** 2, channels[max(0, skip_stages - 1)], 1) + self.proj_out = nn.Conv2d(channels[max(0, skip_stages - 1)], c_in * self.patch_size ** 2 + (1 if self.has_variance else 0), 1) + nn.init.zeros_(self.proj_out.weight) + nn.init.zeros_(self.proj_out.bias) + if cross_cond_dim == 0: + cross_attn_depths = [False] * len(self_attn_depths) + d_blocks, u_blocks = [], [] + for i in range(len(depths)): + my_c_in = channels[max(0, i - 1)] + d_blocks.append(DBlock(depths[i], feats_in, my_c_in, channels[i], channels[i], downsample=i > skip_stages, self_attn=self_attn_depths[i], cross_attn=cross_attn_depths[i], c_enc=cross_cond_dim, dropout_rate=dropout_rate)) + for i in range(len(depths)): + my_c_in = channels[i] * 2 if i < len(depths) - 1 else channels[i] + my_c_out = channels[max(0, i - 1)] + u_blocks.append(UBlock(depths[i], feats_in, my_c_in, channels[i], my_c_out, upsample=i > skip_stages, self_attn=self_attn_depths[i], cross_attn=cross_attn_depths[i], c_enc=cross_cond_dim, dropout_rate=dropout_rate)) + self.u_net = layers.UNet(d_blocks, reversed(u_blocks), skip_stages=skip_stages) + + def forward(self, input, sigma, mapping_cond=None, unet_cond=None, cross_cond=None, cross_cond_padding=None, return_variance=False): + c_noise = sigma.log() / 4 + timestep_embed = self.timestep_embed(utils.append_dims(c_noise, 2)) + mapping_cond_embed = torch.zeros_like(timestep_embed) if mapping_cond is None else self.mapping_cond(mapping_cond) + mapping_out = self.mapping(timestep_embed + mapping_cond_embed) + cond = {'cond': mapping_out} + if unet_cond is not None: + input = torch.cat([input, unet_cond], dim=1) + if cross_cond is not None: + cond['cross'] = cross_cond + cond['cross_padding'] = cross_cond_padding + if self.patch_size > 1: + input = F.pixel_unshuffle(input, self.patch_size) + input = self.proj_in(input) + input = self.u_net(input, cond) + input = self.proj_out(input) + if self.has_variance: + input, logvar = input[:, :-1], input[:, -1].flatten(1).mean(1) + if self.patch_size > 1: + input = F.pixel_shuffle(input, self.patch_size) + if self.has_variance and return_variance: + return input, logvar + return input + + def set_skip_stages(self, skip_stages): + self.proj_in = nn.Conv2d(self.proj_in.in_channels, self.channels[max(0, skip_stages - 1)], 1) + self.proj_out = nn.Conv2d(self.channels[max(0, skip_stages - 1)], self.proj_out.out_channels, 1) + nn.init.zeros_(self.proj_out.weight) + nn.init.zeros_(self.proj_out.bias) + self.u_net.skip_stages = skip_stages + for i, block in enumerate(self.u_net.d_blocks): + block.set_downsample(i > skip_stages) + for i, block in enumerate(reversed(self.u_net.u_blocks)): + block.set_upsample(i > skip_stages) + return self + + def set_patch_size(self, patch_size): + self.patch_size = patch_size + self.proj_in = nn.Conv2d((self.c_in + self.unet_cond_dim) * self.patch_size ** 2, self.channels[max(0, self.u_net.skip_stages - 1)], 1) + self.proj_out = nn.Conv2d(self.channels[max(0, self.u_net.skip_stages - 1)], self.c_in * self.patch_size ** 2 + (1 if self.has_variance else 0), 1) + nn.init.zeros_(self.proj_out.weight) + nn.init.zeros_(self.proj_out.bias) diff --git a/deforum-stable-diffusion/src/k_diffusion/sampling.py b/deforum-stable-diffusion/src/k_diffusion/sampling.py new file mode 100644 index 0000000000000000000000000000000000000000..a1f679ca185d362219519e767e57b1ad456d68b1 --- /dev/null +++ b/deforum-stable-diffusion/src/k_diffusion/sampling.py @@ -0,0 +1,494 @@ +import math + +from scipy import integrate +import torch +from torch import nn +from torchdiffeq import odeint +from tqdm.auto import trange, tqdm + +from . import utils + + +def append_zero(x): + return torch.cat([x, x.new_zeros([1])]) + + +def get_sigmas_karras(n, sigma_min, sigma_max, rho=7., device='cpu'): + """Constructs the noise schedule of Karras et al. (2022).""" + ramp = torch.linspace(0, 1, n) + min_inv_rho = sigma_min ** (1 / rho) + max_inv_rho = sigma_max ** (1 / rho) + sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho + return append_zero(sigmas).to(device) + + +def get_sigmas_exponential(n, sigma_min, sigma_max, device='cpu'): + """Constructs an exponential noise schedule.""" + sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), n, device=device).exp() + return append_zero(sigmas) + + +def get_sigmas_vp(n, beta_d=19.9, beta_min=0.1, eps_s=1e-3, device='cpu'): + """Constructs a continuous VP noise schedule.""" + t = torch.linspace(1, eps_s, n, device=device) + sigmas = torch.sqrt(torch.exp(beta_d * t ** 2 / 2 + beta_min * t) - 1) + return append_zero(sigmas) + + +def to_d(x, sigma, denoised): + """Converts a denoiser output to a Karras ODE derivative.""" + return (x - denoised) / utils.append_dims(sigma, x.ndim) + + +def get_ancestral_step(sigma_from, sigma_to, eta=1.): + """Calculates the noise level (sigma_down) to step down to and the amount + of noise to add (sigma_up) when doing an ancestral sampling step.""" + if not eta: + return sigma_to, 0. + sigma_up = min(sigma_to, eta * (sigma_to ** 2 * (sigma_from ** 2 - sigma_to ** 2) / sigma_from ** 2) ** 0.5) + sigma_down = (sigma_to ** 2 - sigma_up ** 2) ** 0.5 + return sigma_down, sigma_up + + +@torch.no_grad() +def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.): + """Implements Algorithm 2 (Euler steps) from Karras et al. (2022).""" + extra_args = {} if extra_args is None else extra_args + s_in = x.new_ones([x.shape[0]]) + for i in trange(len(sigmas) - 1, disable=disable): + gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0. + eps = torch.randn_like(x) * s_noise + sigma_hat = sigmas[i] * (gamma + 1) + if gamma > 0: + x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5 + denoised = model(x, sigma_hat * s_in, **extra_args) + d = to_d(x, sigma_hat, denoised) + if callback is not None: + callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised}) + dt = sigmas[i + 1] - sigma_hat + # Euler method + x = x + d * dt + return x + + +@torch.no_grad() +def sample_euler_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1.): + """Ancestral sampling with Euler method steps.""" + extra_args = {} if extra_args is None else extra_args + s_in = x.new_ones([x.shape[0]]) + for i in trange(len(sigmas) - 1, disable=disable): + denoised = model(x, sigmas[i] * s_in, **extra_args) + sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta) + if callback is not None: + callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + d = to_d(x, sigmas[i], denoised) + # Euler method + dt = sigma_down - sigmas[i] + x = x + d * dt + x = x + torch.randn_like(x) * sigma_up + return x + + +@torch.no_grad() +def sample_heun(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.): + """Implements Algorithm 2 (Heun steps) from Karras et al. (2022).""" + extra_args = {} if extra_args is None else extra_args + s_in = x.new_ones([x.shape[0]]) + for i in trange(len(sigmas) - 1, disable=disable): + gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0. + eps = torch.randn_like(x) * s_noise + sigma_hat = sigmas[i] * (gamma + 1) + if gamma > 0: + x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5 + denoised = model(x, sigma_hat * s_in, **extra_args) + d = to_d(x, sigma_hat, denoised) + if callback is not None: + callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised}) + dt = sigmas[i + 1] - sigma_hat + if sigmas[i + 1] == 0: + # Euler method + x = x + d * dt + else: + # Heun's method + x_2 = x + d * dt + denoised_2 = model(x_2, sigmas[i + 1] * s_in, **extra_args) + d_2 = to_d(x_2, sigmas[i + 1], denoised_2) + d_prime = (d + d_2) / 2 + x = x + d_prime * dt + return x + + +@torch.no_grad() +def sample_dpm_2(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.): + """A sampler inspired by DPM-Solver-2 and Algorithm 2 from Karras et al. (2022).""" + extra_args = {} if extra_args is None else extra_args + s_in = x.new_ones([x.shape[0]]) + for i in trange(len(sigmas) - 1, disable=disable): + gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0. + eps = torch.randn_like(x) * s_noise + sigma_hat = sigmas[i] * (gamma + 1) + if gamma > 0: + x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5 + denoised = model(x, sigma_hat * s_in, **extra_args) + d = to_d(x, sigma_hat, denoised) + if callback is not None: + callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised}) + if sigmas[i + 1] == 0: + # Euler method + dt = sigmas[i + 1] - sigma_hat + x = x + d * dt + else: + # DPM-Solver-2 + sigma_mid = sigma_hat.log().lerp(sigmas[i + 1].log(), 0.5).exp() + dt_1 = sigma_mid - sigma_hat + dt_2 = sigmas[i + 1] - sigma_hat + x_2 = x + d * dt_1 + denoised_2 = model(x_2, sigma_mid * s_in, **extra_args) + d_2 = to_d(x_2, sigma_mid, denoised_2) + x = x + d_2 * dt_2 + return x + + +@torch.no_grad() +def sample_dpm_2_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1.): + """Ancestral sampling with DPM-Solver inspired second-order steps.""" + extra_args = {} if extra_args is None else extra_args + s_in = x.new_ones([x.shape[0]]) + for i in trange(len(sigmas) - 1, disable=disable): + denoised = model(x, sigmas[i] * s_in, **extra_args) + sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta) + if callback is not None: + callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + d = to_d(x, sigmas[i], denoised) + if sigma_down == 0: + # Euler method + dt = sigma_down - sigmas[i] + x = x + d * dt + else: + # DPM-Solver-2 + sigma_mid = sigmas[i].log().lerp(sigma_down.log(), 0.5).exp() + dt_1 = sigma_mid - sigmas[i] + dt_2 = sigma_down - sigmas[i] + x_2 = x + d * dt_1 + denoised_2 = model(x_2, sigma_mid * s_in, **extra_args) + d_2 = to_d(x_2, sigma_mid, denoised_2) + x = x + d_2 * dt_2 + x = x + torch.randn_like(x) * sigma_up + return x + + +def linear_multistep_coeff(order, t, i, j): + if order - 1 > i: + raise ValueError(f'Order {order} too high for step {i}') + def fn(tau): + prod = 1. + for k in range(order): + if j == k: + continue + prod *= (tau - t[i - k]) / (t[i - j] - t[i - k]) + return prod + return integrate.quad(fn, t[i], t[i + 1], epsrel=1e-4)[0] + + +@torch.no_grad() +def sample_lms(model, x, sigmas, extra_args=None, callback=None, disable=None, order=4): + extra_args = {} if extra_args is None else extra_args + s_in = x.new_ones([x.shape[0]]) + sigmas_cpu = sigmas.detach().cpu().numpy() + ds = [] + for i in trange(len(sigmas) - 1, disable=disable): + denoised = model(x, sigmas[i] * s_in, **extra_args) + d = to_d(x, sigmas[i], denoised) + ds.append(d) + if len(ds) > order: + ds.pop(0) + if callback is not None: + callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + cur_order = min(i + 1, order) + coeffs = [linear_multistep_coeff(cur_order, sigmas_cpu, i, j) for j in range(cur_order)] + x = x + sum(coeff * d for coeff, d in zip(coeffs, reversed(ds))) + return x + + +@torch.no_grad() +def log_likelihood(model, x, sigma_min, sigma_max, extra_args=None, atol=1e-4, rtol=1e-4): + extra_args = {} if extra_args is None else extra_args + s_in = x.new_ones([x.shape[0]]) + v = torch.randint_like(x, 2) * 2 - 1 + fevals = 0 + def ode_fn(sigma, x): + nonlocal fevals + with torch.enable_grad(): + x = x[0].detach().requires_grad_() + denoised = model(x, sigma * s_in, **extra_args) + d = to_d(x, sigma, denoised) + fevals += 1 + grad = torch.autograd.grad((d * v).sum(), x)[0] + d_ll = (v * grad).flatten(1).sum(1) + return d.detach(), d_ll + x_min = x, x.new_zeros([x.shape[0]]) + t = x.new_tensor([sigma_min, sigma_max]) + sol = odeint(ode_fn, x_min, t, atol=atol, rtol=rtol, method='dopri5') + latent, delta_ll = sol[0][-1], sol[1][-1] + ll_prior = torch.distributions.Normal(0, sigma_max).log_prob(latent).flatten(1).sum(1) + return ll_prior + delta_ll, {'fevals': fevals} + + +class PIDStepSizeController: + """A PID controller for ODE adaptive step size control.""" + def __init__(self, h, pcoeff, icoeff, dcoeff, order=1, accept_safety=0.81, eps=1e-8): + self.h = h + self.b1 = (pcoeff + icoeff + dcoeff) / order + self.b2 = -(pcoeff + 2 * dcoeff) / order + self.b3 = dcoeff / order + self.accept_safety = accept_safety + self.eps = eps + self.errs = [] + + def limiter(self, x): + return 1 + math.atan(x - 1) + + def propose_step(self, error): + inv_error = 1 / (float(error) + self.eps) + if not self.errs: + self.errs = [inv_error, inv_error, inv_error] + self.errs[0] = inv_error + factor = self.errs[0] ** self.b1 * self.errs[1] ** self.b2 * self.errs[2] ** self.b3 + factor = self.limiter(factor) + accept = factor >= self.accept_safety + if accept: + self.errs[2] = self.errs[1] + self.errs[1] = self.errs[0] + self.h *= factor + return accept + + +class DPMSolver(nn.Module): + """DPM-Solver. See https://arxiv.org/abs/2206.00927.""" + + def __init__(self, model, extra_args=None, eps_callback=None, info_callback=None): + super().__init__() + self.model = model + self.extra_args = {} if extra_args is None else extra_args + self.eps_callback = eps_callback + self.info_callback = info_callback + + def t(self, sigma): + return -sigma.log() + + def sigma(self, t): + return t.neg().exp() + + def eps(self, eps_cache, key, x, t, *args, **kwargs): + if key in eps_cache: + return eps_cache[key], eps_cache + sigma = self.sigma(t) * x.new_ones([x.shape[0]]) + eps = (x - self.model(x, sigma, *args, **self.extra_args, **kwargs)) / self.sigma(t) + if self.eps_callback is not None: + self.eps_callback() + return eps, {key: eps, **eps_cache} + + def dpm_solver_1_step(self, x, t, t_next, eps_cache=None): + eps_cache = {} if eps_cache is None else eps_cache + h = t_next - t + eps, eps_cache = self.eps(eps_cache, 'eps', x, t) + x_1 = x - self.sigma(t_next) * h.expm1() * eps + return x_1, eps_cache + + def dpm_solver_2_step(self, x, t, t_next, r1=1 / 2, eps_cache=None): + eps_cache = {} if eps_cache is None else eps_cache + h = t_next - t + eps, eps_cache = self.eps(eps_cache, 'eps', x, t) + s1 = t + r1 * h + u1 = x - self.sigma(s1) * (r1 * h).expm1() * eps + eps_r1, eps_cache = self.eps(eps_cache, 'eps_r1', u1, s1) + x_2 = x - self.sigma(t_next) * h.expm1() * eps - self.sigma(t_next) / (2 * r1) * h.expm1() * (eps_r1 - eps) + return x_2, eps_cache + + def dpm_solver_3_step(self, x, t, t_next, r1=1 / 3, r2=2 / 3, eps_cache=None): + eps_cache = {} if eps_cache is None else eps_cache + h = t_next - t + eps, eps_cache = self.eps(eps_cache, 'eps', x, t) + s1 = t + r1 * h + s2 = t + r2 * h + u1 = x - self.sigma(s1) * (r1 * h).expm1() * eps + eps_r1, eps_cache = self.eps(eps_cache, 'eps_r1', u1, s1) + u2 = x - self.sigma(s2) * (r2 * h).expm1() * eps - self.sigma(s2) * (r2 / r1) * ((r2 * h).expm1() / (r2 * h) - 1) * (eps_r1 - eps) + eps_r2, eps_cache = self.eps(eps_cache, 'eps_r2', u2, s2) + x_3 = x - self.sigma(t_next) * h.expm1() * eps - self.sigma(t_next) / r2 * (h.expm1() / h - 1) * (eps_r2 - eps) + return x_3, eps_cache + + def dpm_solver_fast(self, x, t_start, t_end, nfe, eta=0., s_noise=1.): + if not t_end > t_start and eta: + raise ValueError('eta must be 0 for reverse sampling') + + m = math.floor(nfe / 3) + 1 + ts = torch.linspace(t_start, t_end, m + 1, device=x.device) + + if nfe % 3 == 0: + orders = [3] * (m - 2) + [2, 1] + else: + orders = [3] * (m - 1) + [nfe % 3] + + for i in range(len(orders)): + eps_cache = {} + t, t_next = ts[i], ts[i + 1] + if eta: + sd, su = get_ancestral_step(self.sigma(t), self.sigma(t_next), eta) + t_next_ = torch.minimum(t_end, self.t(sd)) + su = (self.sigma(t_next) ** 2 - self.sigma(t_next_) ** 2) ** 0.5 + else: + t_next_, su = t_next, 0. + + eps, eps_cache = self.eps(eps_cache, 'eps', x, t) + denoised = x - self.sigma(t) * eps + if self.info_callback is not None: + self.info_callback({'x': x, 'i': i, 't': ts[i], 't_up': t, 'denoised': denoised}) + + if orders[i] == 1: + x, eps_cache = self.dpm_solver_1_step(x, t, t_next_, eps_cache=eps_cache) + elif orders[i] == 2: + x, eps_cache = self.dpm_solver_2_step(x, t, t_next_, eps_cache=eps_cache) + else: + x, eps_cache = self.dpm_solver_3_step(x, t, t_next_, eps_cache=eps_cache) + + x = x + su * s_noise * torch.randn_like(x) + + return x + + def dpm_solver_adaptive(self, x, t_start, t_end, order=3, rtol=0.05, atol=0.0078, h_init=0.05, pcoeff=0., icoeff=1., dcoeff=0., accept_safety=0.81, eta=0., s_noise=1.): + if order not in {2, 3}: + raise ValueError('order should be 2 or 3') + forward = t_end > t_start + if not forward and eta: + raise ValueError('eta must be 0 for reverse sampling') + h_init = abs(h_init) * (1 if forward else -1) + atol = torch.tensor(atol) + rtol = torch.tensor(rtol) + s = t_start + x_prev = x + accept = True + pid = PIDStepSizeController(h_init, pcoeff, icoeff, dcoeff, 1.5 if eta else order, accept_safety) + info = {'steps': 0, 'nfe': 0, 'n_accept': 0, 'n_reject': 0} + + while s < t_end - 1e-5 if forward else s > t_end + 1e-5: + eps_cache = {} + t = torch.minimum(t_end, s + pid.h) if forward else torch.maximum(t_end, s + pid.h) + if eta: + sd, su = get_ancestral_step(self.sigma(s), self.sigma(t), eta) + t_ = torch.minimum(t_end, self.t(sd)) + su = (self.sigma(t) ** 2 - self.sigma(t_) ** 2) ** 0.5 + else: + t_, su = t, 0. + + eps, eps_cache = self.eps(eps_cache, 'eps', x, s) + denoised = x - self.sigma(s) * eps + + if order == 2: + x_low, eps_cache = self.dpm_solver_1_step(x, s, t_, eps_cache=eps_cache) + x_high, eps_cache = self.dpm_solver_2_step(x, s, t_, eps_cache=eps_cache) + else: + x_low, eps_cache = self.dpm_solver_2_step(x, s, t_, r1=1 / 3, eps_cache=eps_cache) + x_high, eps_cache = self.dpm_solver_3_step(x, s, t_, eps_cache=eps_cache) + delta = torch.maximum(atol, rtol * torch.maximum(x_low.abs(), x_prev.abs())) + error = torch.linalg.norm((x_low - x_high) / delta) / x.numel() ** 0.5 + accept = pid.propose_step(error) + if accept: + x_prev = x_low + x = x_high + su * s_noise * torch.randn_like(x_high) + s = t + info['n_accept'] += 1 + else: + info['n_reject'] += 1 + info['nfe'] += order + info['steps'] += 1 + + if self.info_callback is not None: + self.info_callback({'x': x, 'i': info['steps'] - 1, 't': s, 't_up': s, 'denoised': denoised, 'error': error, 'h': pid.h, **info}) + + return x, info + + +@torch.no_grad() +def sample_dpm_fast(model, x, sigma_min, sigma_max, n, extra_args=None, callback=None, disable=None, eta=0., s_noise=1.): + """DPM-Solver-Fast (fixed step size). See https://arxiv.org/abs/2206.00927.""" + if sigma_min <= 0 or sigma_max <= 0: + raise ValueError('sigma_min and sigma_max must not be 0') + with tqdm(total=n, disable=disable) as pbar: + dpm_solver = DPMSolver(model, extra_args, eps_callback=pbar.update) + if callback is not None: + dpm_solver.info_callback = lambda info: callback({'sigma': dpm_solver.sigma(info['t']), 'sigma_hat': dpm_solver.sigma(info['t_up']), **info}) + return dpm_solver.dpm_solver_fast(x, dpm_solver.t(torch.tensor(sigma_max)), dpm_solver.t(torch.tensor(sigma_min)), n, eta, s_noise) + + +@torch.no_grad() +def sample_dpm_adaptive(model, x, sigma_min, sigma_max, extra_args=None, callback=None, disable=None, order=3, rtol=0.05, atol=0.0078, h_init=0.05, pcoeff=0., icoeff=1., dcoeff=0., accept_safety=0.81, eta=0., s_noise=1., return_info=False): + """DPM-Solver-12 and 23 (adaptive step size). See https://arxiv.org/abs/2206.00927.""" + if sigma_min <= 0 or sigma_max <= 0: + raise ValueError('sigma_min and sigma_max must not be 0') + with tqdm(disable=disable) as pbar: + dpm_solver = DPMSolver(model, extra_args, eps_callback=pbar.update) + if callback is not None: + dpm_solver.info_callback = lambda info: callback({'sigma': dpm_solver.sigma(info['t']), 'sigma_hat': dpm_solver.sigma(info['t_up']), **info}) + x, info = dpm_solver.dpm_solver_adaptive(x, dpm_solver.t(torch.tensor(sigma_max)), dpm_solver.t(torch.tensor(sigma_min)), order, rtol, atol, h_init, pcoeff, icoeff, dcoeff, accept_safety, eta, s_noise) + if return_info: + return x, info + return x + + +@torch.no_grad() +def sample_dpmpp_2s_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1.): + """Ancestral sampling with DPM-Solver++(2S) second-order steps.""" + extra_args = {} if extra_args is None else extra_args + s_in = x.new_ones([x.shape[0]]) + sigma_fn = lambda t: t.neg().exp() + t_fn = lambda sigma: sigma.log().neg() + + for i in trange(len(sigmas) - 1, disable=disable): + denoised = model(x, sigmas[i] * s_in, **extra_args) + sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta) + if callback is not None: + callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + if sigma_down == 0: + # Euler method + d = to_d(x, sigmas[i], denoised) + dt = sigma_down - sigmas[i] + x = x + d * dt + else: + # DPM-Solver-2++(2S) + t, t_next = t_fn(sigmas[i]), t_fn(sigma_down) + r = 1 / 2 + h = t_next - t + s = t + r * h + x_2 = (sigma_fn(s) / sigma_fn(t)) * x - (-h * r).expm1() * denoised + denoised_2 = model(x_2, sigma_fn(s) * s_in, **extra_args) + x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_2 + # Noise addition + x = x + torch.randn_like(x) * s_noise * sigma_up + return x + + +@torch.no_grad() +def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=None): + """DPM-Solver++(2M).""" + extra_args = {} if extra_args is None else extra_args + s_in = x.new_ones([x.shape[0]]) + sigma_fn = lambda t: t.neg().exp() + t_fn = lambda sigma: sigma.log().neg() + old_denoised = None + + for i in trange(len(sigmas) - 1, disable=disable): + denoised = model(x, sigmas[i] * s_in, **extra_args) + if callback is not None: + callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1]) + h = t_next - t + if old_denoised is None or sigmas[i + 1] == 0: + x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised + else: + h_last = t - t_fn(sigmas[i - 1]) + r = h_last / h + denoised_d = (1 + 1 / (2 * r)) * denoised - (1 / (2 * r)) * old_denoised + x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_d + old_denoised = denoised + return x diff --git a/deforum-stable-diffusion/src/k_diffusion/utils.py b/deforum-stable-diffusion/src/k_diffusion/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9afedb99276d55d5b923a04ffb62d403c9dfae93 --- /dev/null +++ b/deforum-stable-diffusion/src/k_diffusion/utils.py @@ -0,0 +1,329 @@ +from contextlib import contextmanager +import hashlib +import math +from pathlib import Path +import shutil +import urllib +import warnings + +from PIL import Image +import torch +from torch import nn, optim +from torch.utils import data +from torchvision.transforms import functional as TF + + +def from_pil_image(x): + """Converts from a PIL image to a tensor.""" + x = TF.to_tensor(x) + if x.ndim == 2: + x = x[..., None] + return x * 2 - 1 + + +def to_pil_image(x): + """Converts from a tensor to a PIL image.""" + if x.ndim == 4: + assert x.shape[0] == 1 + x = x[0] + if x.shape[0] == 1: + x = x[0] + return TF.to_pil_image((x.clamp(-1, 1) + 1) / 2) + + +def hf_datasets_augs_helper(examples, transform, image_key, mode='RGB'): + """Apply passed in transforms for HuggingFace Datasets.""" + images = [transform(image.convert(mode)) for image in examples[image_key]] + return {image_key: images} + + +def append_dims(x, target_dims): + """Appends dimensions to the end of a tensor until it has target_dims dimensions.""" + dims_to_append = target_dims - x.ndim + if dims_to_append < 0: + raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less') + return x[(...,) + (None,) * dims_to_append] + + +def n_params(module): + """Returns the number of trainable parameters in a module.""" + return sum(p.numel() for p in module.parameters()) + + +def download_file(path, url, digest=None): + """Downloads a file if it does not exist, optionally checking its SHA-256 hash.""" + path = Path(path) + path.parent.mkdir(parents=True, exist_ok=True) + if not path.exists(): + with urllib.request.urlopen(url) as response, open(path, 'wb') as f: + shutil.copyfileobj(response, f) + if digest is not None: + file_digest = hashlib.sha256(open(path, 'rb').read()).hexdigest() + if digest != file_digest: + raise OSError(f'hash of {path} (url: {url}) failed to validate') + return path + + +@contextmanager +def train_mode(model, mode=True): + """A context manager that places a model into training mode and restores + the previous mode on exit.""" + modes = [module.training for module in model.modules()] + try: + yield model.train(mode) + finally: + for i, module in enumerate(model.modules()): + module.training = modes[i] + + +def eval_mode(model): + """A context manager that places a model into evaluation mode and restores + the previous mode on exit.""" + return train_mode(model, False) + + +@torch.no_grad() +def ema_update(model, averaged_model, decay): + """Incorporates updated model parameters into an exponential moving averaged + version of a model. It should be called after each optimizer step.""" + model_params = dict(model.named_parameters()) + averaged_params = dict(averaged_model.named_parameters()) + assert model_params.keys() == averaged_params.keys() + + for name, param in model_params.items(): + averaged_params[name].mul_(decay).add_(param, alpha=1 - decay) + + model_buffers = dict(model.named_buffers()) + averaged_buffers = dict(averaged_model.named_buffers()) + assert model_buffers.keys() == averaged_buffers.keys() + + for name, buf in model_buffers.items(): + averaged_buffers[name].copy_(buf) + + +class EMAWarmup: + """Implements an EMA warmup using an inverse decay schedule. + If inv_gamma=1 and power=1, implements a simple average. inv_gamma=1, power=2/3 are + good values for models you plan to train for a million or more steps (reaches decay + factor 0.999 at 31.6K steps, 0.9999 at 1M steps), inv_gamma=1, power=3/4 for models + you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 at + 215.4k steps). + Args: + inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1. + power (float): Exponential factor of EMA warmup. Default: 1. + min_value (float): The minimum EMA decay rate. Default: 0. + max_value (float): The maximum EMA decay rate. Default: 1. + start_at (int): The epoch to start averaging at. Default: 0. + last_epoch (int): The index of last epoch. Default: 0. + """ + + def __init__(self, inv_gamma=1., power=1., min_value=0., max_value=1., start_at=0, + last_epoch=0): + self.inv_gamma = inv_gamma + self.power = power + self.min_value = min_value + self.max_value = max_value + self.start_at = start_at + self.last_epoch = last_epoch + + def state_dict(self): + """Returns the state of the class as a :class:`dict`.""" + return dict(self.__dict__.items()) + + def load_state_dict(self, state_dict): + """Loads the class's state. + Args: + state_dict (dict): scaler state. Should be an object returned + from a call to :meth:`state_dict`. + """ + self.__dict__.update(state_dict) + + def get_value(self): + """Gets the current EMA decay rate.""" + epoch = max(0, self.last_epoch - self.start_at) + value = 1 - (1 + epoch / self.inv_gamma) ** -self.power + return 0. if epoch < 0 else min(self.max_value, max(self.min_value, value)) + + def step(self): + """Updates the step count.""" + self.last_epoch += 1 + + +class InverseLR(optim.lr_scheduler._LRScheduler): + """Implements an inverse decay learning rate schedule with an optional exponential + warmup. When last_epoch=-1, sets initial lr as lr. + inv_gamma is the number of steps/epochs required for the learning rate to decay to + (1 / 2)**power of its original value. + Args: + optimizer (Optimizer): Wrapped optimizer. + inv_gamma (float): Inverse multiplicative factor of learning rate decay. Default: 1. + power (float): Exponential factor of learning rate decay. Default: 1. + warmup (float): Exponential warmup factor (0 <= warmup < 1, 0 to disable) + Default: 0. + min_lr (float): The minimum learning rate. Default: 0. + last_epoch (int): The index of last epoch. Default: -1. + verbose (bool): If ``True``, prints a message to stdout for + each update. Default: ``False``. + """ + + def __init__(self, optimizer, inv_gamma=1., power=1., warmup=0., min_lr=0., + last_epoch=-1, verbose=False): + self.inv_gamma = inv_gamma + self.power = power + if not 0. <= warmup < 1: + raise ValueError('Invalid value for warmup') + self.warmup = warmup + self.min_lr = min_lr + super().__init__(optimizer, last_epoch, verbose) + + def get_lr(self): + if not self._get_lr_called_within_step: + warnings.warn("To get the last learning rate computed by the scheduler, " + "please use `get_last_lr()`.") + + return self._get_closed_form_lr() + + def _get_closed_form_lr(self): + warmup = 1 - self.warmup ** (self.last_epoch + 1) + lr_mult = (1 + self.last_epoch / self.inv_gamma) ** -self.power + return [warmup * max(self.min_lr, base_lr * lr_mult) + for base_lr in self.base_lrs] + + +class ExponentialLR(optim.lr_scheduler._LRScheduler): + """Implements an exponential learning rate schedule with an optional exponential + warmup. When last_epoch=-1, sets initial lr as lr. Decays the learning rate + continuously by decay (default 0.5) every num_steps steps. + Args: + optimizer (Optimizer): Wrapped optimizer. + num_steps (float): The number of steps to decay the learning rate by decay in. + decay (float): The factor by which to decay the learning rate every num_steps + steps. Default: 0.5. + warmup (float): Exponential warmup factor (0 <= warmup < 1, 0 to disable) + Default: 0. + min_lr (float): The minimum learning rate. Default: 0. + last_epoch (int): The index of last epoch. Default: -1. + verbose (bool): If ``True``, prints a message to stdout for + each update. Default: ``False``. + """ + + def __init__(self, optimizer, num_steps, decay=0.5, warmup=0., min_lr=0., + last_epoch=-1, verbose=False): + self.num_steps = num_steps + self.decay = decay + if not 0. <= warmup < 1: + raise ValueError('Invalid value for warmup') + self.warmup = warmup + self.min_lr = min_lr + super().__init__(optimizer, last_epoch, verbose) + + def get_lr(self): + if not self._get_lr_called_within_step: + warnings.warn("To get the last learning rate computed by the scheduler, " + "please use `get_last_lr()`.") + + return self._get_closed_form_lr() + + def _get_closed_form_lr(self): + warmup = 1 - self.warmup ** (self.last_epoch + 1) + lr_mult = (self.decay ** (1 / self.num_steps)) ** self.last_epoch + return [warmup * max(self.min_lr, base_lr * lr_mult) + for base_lr in self.base_lrs] + + +def rand_log_normal(shape, loc=0., scale=1., device='cpu', dtype=torch.float32): + """Draws samples from an lognormal distribution.""" + return (torch.randn(shape, device=device, dtype=dtype) * scale + loc).exp() + + +def rand_log_logistic(shape, loc=0., scale=1., min_value=0., max_value=float('inf'), device='cpu', dtype=torch.float32): + """Draws samples from an optionally truncated log-logistic distribution.""" + min_value = torch.as_tensor(min_value, device=device, dtype=torch.float64) + max_value = torch.as_tensor(max_value, device=device, dtype=torch.float64) + min_cdf = min_value.log().sub(loc).div(scale).sigmoid() + max_cdf = max_value.log().sub(loc).div(scale).sigmoid() + u = torch.rand(shape, device=device, dtype=torch.float64) * (max_cdf - min_cdf) + min_cdf + return u.logit().mul(scale).add(loc).exp().to(dtype) + + +def rand_log_uniform(shape, min_value, max_value, device='cpu', dtype=torch.float32): + """Draws samples from an log-uniform distribution.""" + min_value = math.log(min_value) + max_value = math.log(max_value) + return (torch.rand(shape, device=device, dtype=dtype) * (max_value - min_value) + min_value).exp() + + +def rand_v_diffusion(shape, sigma_data=1., min_value=0., max_value=float('inf'), device='cpu', dtype=torch.float32): + """Draws samples from a truncated v-diffusion training timestep distribution.""" + min_cdf = math.atan(min_value / sigma_data) * 2 / math.pi + max_cdf = math.atan(max_value / sigma_data) * 2 / math.pi + u = torch.rand(shape, device=device, dtype=dtype) * (max_cdf - min_cdf) + min_cdf + return torch.tan(u * math.pi / 2) * sigma_data + + +def rand_split_log_normal(shape, loc, scale_1, scale_2, device='cpu', dtype=torch.float32): + """Draws samples from a split lognormal distribution.""" + n = torch.randn(shape, device=device, dtype=dtype).abs() + u = torch.rand(shape, device=device, dtype=dtype) + n_left = n * -scale_1 + loc + n_right = n * scale_2 + loc + ratio = scale_1 / (scale_1 + scale_2) + return torch.where(u < ratio, n_left, n_right).exp() + + +class FolderOfImages(data.Dataset): + """Recursively finds all images in a directory. It does not support + classes/targets.""" + + IMG_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp'} + + def __init__(self, root, transform=None): + super().__init__() + self.root = Path(root) + self.transform = nn.Identity() if transform is None else transform + self.paths = sorted(path for path in self.root.rglob('*') if path.suffix.lower() in self.IMG_EXTENSIONS) + + def __repr__(self): + return f'FolderOfImages(root="{self.root}", len: {len(self)})' + + def __len__(self): + return len(self.paths) + + def __getitem__(self, key): + path = self.paths[key] + with open(path, 'rb') as f: + image = Image.open(f).convert('RGB') + image = self.transform(image) + return image, + + +class CSVLogger: + def __init__(self, filename, columns): + self.filename = Path(filename) + self.columns = columns + if self.filename.exists(): + self.file = open(self.filename, 'a') + else: + self.file = open(self.filename, 'w') + self.write(*self.columns) + + def write(self, *args): + print(*args, sep=',', file=self.file, flush=True) + + +@contextmanager +def tf32_mode(cudnn=None, matmul=None): + """A context manager that sets whether TF32 is allowed on cuDNN or matmul.""" + cudnn_old = torch.backends.cudnn.allow_tf32 + matmul_old = torch.backends.cuda.matmul.allow_tf32 + try: + if cudnn is not None: + torch.backends.cudnn.allow_tf32 = cudnn + if matmul is not None: + torch.backends.cuda.matmul.allow_tf32 = matmul + yield + finally: + if cudnn is not None: + torch.backends.cudnn.allow_tf32 = cudnn_old + if matmul is not None: + torch.backends.cuda.matmul.allow_tf32 = matmul_old diff --git a/deforum-stable-diffusion/src/ldm/__pycache__/util.cpython-39.pyc b/deforum-stable-diffusion/src/ldm/__pycache__/util.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..856982ef419519044d02309efab8c477658e8665 Binary files /dev/null and b/deforum-stable-diffusion/src/ldm/__pycache__/util.cpython-39.pyc differ diff --git a/deforum-stable-diffusion/src/ldm/data/__init__.py b/deforum-stable-diffusion/src/ldm/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/deforum-stable-diffusion/src/ldm/data/base.py b/deforum-stable-diffusion/src/ldm/data/base.py new file mode 100644 index 0000000000000000000000000000000000000000..b196c2f7aa583a3e8bc4aad9f943df0c4dae0da7 --- /dev/null +++ b/deforum-stable-diffusion/src/ldm/data/base.py @@ -0,0 +1,23 @@ +from abc import abstractmethod +from torch.utils.data import Dataset, ConcatDataset, ChainDataset, IterableDataset + + +class Txt2ImgIterableBaseDataset(IterableDataset): + ''' + Define an interface to make the IterableDatasets for text2img data chainable + ''' + def __init__(self, num_records=0, valid_ids=None, size=256): + super().__init__() + self.num_records = num_records + self.valid_ids = valid_ids + self.sample_ids = valid_ids + self.size = size + + print(f'{self.__class__.__name__} dataset contains {self.__len__()} examples.') + + def __len__(self): + return self.num_records + + @abstractmethod + def __iter__(self): + pass \ No newline at end of file diff --git a/deforum-stable-diffusion/src/ldm/data/imagenet.py b/deforum-stable-diffusion/src/ldm/data/imagenet.py new file mode 100644 index 0000000000000000000000000000000000000000..1c473f9c6965b22315dbb289eff8247c71bdc790 --- /dev/null +++ b/deforum-stable-diffusion/src/ldm/data/imagenet.py @@ -0,0 +1,394 @@ +import os, yaml, pickle, shutil, tarfile, glob +import cv2 +import albumentations +import PIL +import numpy as np +import torchvision.transforms.functional as TF +from omegaconf import OmegaConf +from functools import partial +from PIL import Image +from tqdm import tqdm +from torch.utils.data import Dataset, Subset + +import taming.data.utils as tdu +from taming.data.imagenet import str_to_indices, give_synsets_from_indices, download, retrieve +from taming.data.imagenet import ImagePaths + +from ldm.modules.image_degradation import degradation_fn_bsr, degradation_fn_bsr_light + + +def synset2idx(path_to_yaml="data/index_synset.yaml"): + with open(path_to_yaml) as f: + di2s = yaml.load(f) + return dict((v,k) for k,v in di2s.items()) + + +class ImageNetBase(Dataset): + def __init__(self, config=None): + self.config = config or OmegaConf.create() + if not type(self.config)==dict: + self.config = OmegaConf.to_container(self.config) + self.keep_orig_class_label = self.config.get("keep_orig_class_label", False) + self.process_images = True # if False we skip loading & processing images and self.data contains filepaths + self._prepare() + self._prepare_synset_to_human() + self._prepare_idx_to_synset() + self._prepare_human_to_integer_label() + self._load() + + def __len__(self): + return len(self.data) + + def __getitem__(self, i): + return self.data[i] + + def _prepare(self): + raise NotImplementedError() + + def _filter_relpaths(self, relpaths): + ignore = set([ + "n06596364_9591.JPEG", + ]) + relpaths = [rpath for rpath in relpaths if not rpath.split("/")[-1] in ignore] + if "sub_indices" in self.config: + indices = str_to_indices(self.config["sub_indices"]) + synsets = give_synsets_from_indices(indices, path_to_yaml=self.idx2syn) # returns a list of strings + self.synset2idx = synset2idx(path_to_yaml=self.idx2syn) + files = [] + for rpath in relpaths: + syn = rpath.split("/")[0] + if syn in synsets: + files.append(rpath) + return files + else: + return relpaths + + def _prepare_synset_to_human(self): + SIZE = 2655750 + URL = "https://heibox.uni-heidelberg.de/f/9f28e956cd304264bb82/?dl=1" + self.human_dict = os.path.join(self.root, "synset_human.txt") + if (not os.path.exists(self.human_dict) or + not os.path.getsize(self.human_dict)==SIZE): + download(URL, self.human_dict) + + def _prepare_idx_to_synset(self): + URL = "https://heibox.uni-heidelberg.de/f/d835d5b6ceda4d3aa910/?dl=1" + self.idx2syn = os.path.join(self.root, "index_synset.yaml") + if (not os.path.exists(self.idx2syn)): + download(URL, self.idx2syn) + + def _prepare_human_to_integer_label(self): + URL = "https://heibox.uni-heidelberg.de/f/2362b797d5be43b883f6/?dl=1" + self.human2integer = os.path.join(self.root, "imagenet1000_clsidx_to_labels.txt") + if (not os.path.exists(self.human2integer)): + download(URL, self.human2integer) + with open(self.human2integer, "r") as f: + lines = f.read().splitlines() + assert len(lines) == 1000 + self.human2integer_dict = dict() + for line in lines: + value, key = line.split(":") + self.human2integer_dict[key] = int(value) + + def _load(self): + with open(self.txt_filelist, "r") as f: + self.relpaths = f.read().splitlines() + l1 = len(self.relpaths) + self.relpaths = self._filter_relpaths(self.relpaths) + print("Removed {} files from filelist during filtering.".format(l1 - len(self.relpaths))) + + self.synsets = [p.split("/")[0] for p in self.relpaths] + self.abspaths = [os.path.join(self.datadir, p) for p in self.relpaths] + + unique_synsets = np.unique(self.synsets) + class_dict = dict((synset, i) for i, synset in enumerate(unique_synsets)) + if not self.keep_orig_class_label: + self.class_labels = [class_dict[s] for s in self.synsets] + else: + self.class_labels = [self.synset2idx[s] for s in self.synsets] + + with open(self.human_dict, "r") as f: + human_dict = f.read().splitlines() + human_dict = dict(line.split(maxsplit=1) for line in human_dict) + + self.human_labels = [human_dict[s] for s in self.synsets] + + labels = { + "relpath": np.array(self.relpaths), + "synsets": np.array(self.synsets), + "class_label": np.array(self.class_labels), + "human_label": np.array(self.human_labels), + } + + if self.process_images: + self.size = retrieve(self.config, "size", default=256) + self.data = ImagePaths(self.abspaths, + labels=labels, + size=self.size, + random_crop=self.random_crop, + ) + else: + self.data = self.abspaths + + +class ImageNetTrain(ImageNetBase): + NAME = "ILSVRC2012_train" + URL = "http://www.image-net.org/challenges/LSVRC/2012/" + AT_HASH = "a306397ccf9c2ead27155983c254227c0fd938e2" + FILES = [ + "ILSVRC2012_img_train.tar", + ] + SIZES = [ + 147897477120, + ] + + def __init__(self, process_images=True, data_root=None, **kwargs): + self.process_images = process_images + self.data_root = data_root + super().__init__(**kwargs) + + def _prepare(self): + if self.data_root: + self.root = os.path.join(self.data_root, self.NAME) + else: + cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache")) + self.root = os.path.join(cachedir, "autoencoders/data", self.NAME) + + self.datadir = os.path.join(self.root, "data") + self.txt_filelist = os.path.join(self.root, "filelist.txt") + self.expected_length = 1281167 + self.random_crop = retrieve(self.config, "ImageNetTrain/random_crop", + default=True) + if not tdu.is_prepared(self.root): + # prep + print("Preparing dataset {} in {}".format(self.NAME, self.root)) + + datadir = self.datadir + if not os.path.exists(datadir): + path = os.path.join(self.root, self.FILES[0]) + if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]: + import academictorrents as at + atpath = at.get(self.AT_HASH, datastore=self.root) + assert atpath == path + + print("Extracting {} to {}".format(path, datadir)) + os.makedirs(datadir, exist_ok=True) + with tarfile.open(path, "r:") as tar: + tar.extractall(path=datadir) + + print("Extracting sub-tars.") + subpaths = sorted(glob.glob(os.path.join(datadir, "*.tar"))) + for subpath in tqdm(subpaths): + subdir = subpath[:-len(".tar")] + os.makedirs(subdir, exist_ok=True) + with tarfile.open(subpath, "r:") as tar: + tar.extractall(path=subdir) + + filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG")) + filelist = [os.path.relpath(p, start=datadir) for p in filelist] + filelist = sorted(filelist) + filelist = "\n".join(filelist)+"\n" + with open(self.txt_filelist, "w") as f: + f.write(filelist) + + tdu.mark_prepared(self.root) + + +class ImageNetValidation(ImageNetBase): + NAME = "ILSVRC2012_validation" + URL = "http://www.image-net.org/challenges/LSVRC/2012/" + AT_HASH = "5d6d0df7ed81efd49ca99ea4737e0ae5e3a5f2e5" + VS_URL = "https://heibox.uni-heidelberg.de/f/3e0f6e9c624e45f2bd73/?dl=1" + FILES = [ + "ILSVRC2012_img_val.tar", + "validation_synset.txt", + ] + SIZES = [ + 6744924160, + 1950000, + ] + + def __init__(self, process_images=True, data_root=None, **kwargs): + self.data_root = data_root + self.process_images = process_images + super().__init__(**kwargs) + + def _prepare(self): + if self.data_root: + self.root = os.path.join(self.data_root, self.NAME) + else: + cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache")) + self.root = os.path.join(cachedir, "autoencoders/data", self.NAME) + self.datadir = os.path.join(self.root, "data") + self.txt_filelist = os.path.join(self.root, "filelist.txt") + self.expected_length = 50000 + self.random_crop = retrieve(self.config, "ImageNetValidation/random_crop", + default=False) + if not tdu.is_prepared(self.root): + # prep + print("Preparing dataset {} in {}".format(self.NAME, self.root)) + + datadir = self.datadir + if not os.path.exists(datadir): + path = os.path.join(self.root, self.FILES[0]) + if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]: + import academictorrents as at + atpath = at.get(self.AT_HASH, datastore=self.root) + assert atpath == path + + print("Extracting {} to {}".format(path, datadir)) + os.makedirs(datadir, exist_ok=True) + with tarfile.open(path, "r:") as tar: + tar.extractall(path=datadir) + + vspath = os.path.join(self.root, self.FILES[1]) + if not os.path.exists(vspath) or not os.path.getsize(vspath)==self.SIZES[1]: + download(self.VS_URL, vspath) + + with open(vspath, "r") as f: + synset_dict = f.read().splitlines() + synset_dict = dict(line.split() for line in synset_dict) + + print("Reorganizing into synset folders") + synsets = np.unique(list(synset_dict.values())) + for s in synsets: + os.makedirs(os.path.join(datadir, s), exist_ok=True) + for k, v in synset_dict.items(): + src = os.path.join(datadir, k) + dst = os.path.join(datadir, v) + shutil.move(src, dst) + + filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG")) + filelist = [os.path.relpath(p, start=datadir) for p in filelist] + filelist = sorted(filelist) + filelist = "\n".join(filelist)+"\n" + with open(self.txt_filelist, "w") as f: + f.write(filelist) + + tdu.mark_prepared(self.root) + + + +class ImageNetSR(Dataset): + def __init__(self, size=None, + degradation=None, downscale_f=4, min_crop_f=0.5, max_crop_f=1., + random_crop=True): + """ + Imagenet Superresolution Dataloader + Performs following ops in order: + 1. crops a crop of size s from image either as random or center crop + 2. resizes crop to size with cv2.area_interpolation + 3. degrades resized crop with degradation_fn + + :param size: resizing to size after cropping + :param degradation: degradation_fn, e.g. cv_bicubic or bsrgan_light + :param downscale_f: Low Resolution Downsample factor + :param min_crop_f: determines crop size s, + where s = c * min_img_side_len with c sampled from interval (min_crop_f, max_crop_f) + :param max_crop_f: "" + :param data_root: + :param random_crop: + """ + self.base = self.get_base() + assert size + assert (size / downscale_f).is_integer() + self.size = size + self.LR_size = int(size / downscale_f) + self.min_crop_f = min_crop_f + self.max_crop_f = max_crop_f + assert(max_crop_f <= 1.) + self.center_crop = not random_crop + + self.image_rescaler = albumentations.SmallestMaxSize(max_size=size, interpolation=cv2.INTER_AREA) + + self.pil_interpolation = False # gets reset later if incase interp_op is from pillow + + if degradation == "bsrgan": + self.degradation_process = partial(degradation_fn_bsr, sf=downscale_f) + + elif degradation == "bsrgan_light": + self.degradation_process = partial(degradation_fn_bsr_light, sf=downscale_f) + + else: + interpolation_fn = { + "cv_nearest": cv2.INTER_NEAREST, + "cv_bilinear": cv2.INTER_LINEAR, + "cv_bicubic": cv2.INTER_CUBIC, + "cv_area": cv2.INTER_AREA, + "cv_lanczos": cv2.INTER_LANCZOS4, + "pil_nearest": PIL.Image.NEAREST, + "pil_bilinear": PIL.Image.BILINEAR, + "pil_bicubic": PIL.Image.BICUBIC, + "pil_box": PIL.Image.BOX, + "pil_hamming": PIL.Image.HAMMING, + "pil_lanczos": PIL.Image.LANCZOS, + }[degradation] + + self.pil_interpolation = degradation.startswith("pil_") + + if self.pil_interpolation: + self.degradation_process = partial(TF.resize, size=self.LR_size, interpolation=interpolation_fn) + + else: + self.degradation_process = albumentations.SmallestMaxSize(max_size=self.LR_size, + interpolation=interpolation_fn) + + def __len__(self): + return len(self.base) + + def __getitem__(self, i): + example = self.base[i] + image = Image.open(example["file_path_"]) + + if not image.mode == "RGB": + image = image.convert("RGB") + + image = np.array(image).astype(np.uint8) + + min_side_len = min(image.shape[:2]) + crop_side_len = min_side_len * np.random.uniform(self.min_crop_f, self.max_crop_f, size=None) + crop_side_len = int(crop_side_len) + + if self.center_crop: + self.cropper = albumentations.CenterCrop(height=crop_side_len, width=crop_side_len) + + else: + self.cropper = albumentations.RandomCrop(height=crop_side_len, width=crop_side_len) + + image = self.cropper(image=image)["image"] + image = self.image_rescaler(image=image)["image"] + + if self.pil_interpolation: + image_pil = PIL.Image.fromarray(image) + LR_image = self.degradation_process(image_pil) + LR_image = np.array(LR_image).astype(np.uint8) + + else: + LR_image = self.degradation_process(image=image)["image"] + + example["image"] = (image/127.5 - 1.0).astype(np.float32) + example["LR_image"] = (LR_image/127.5 - 1.0).astype(np.float32) + + return example + + +class ImageNetSRTrain(ImageNetSR): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def get_base(self): + with open("data/imagenet_train_hr_indices.p", "rb") as f: + indices = pickle.load(f) + dset = ImageNetTrain(process_images=False,) + return Subset(dset, indices) + + +class ImageNetSRValidation(ImageNetSR): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def get_base(self): + with open("data/imagenet_val_hr_indices.p", "rb") as f: + indices = pickle.load(f) + dset = ImageNetValidation(process_images=False,) + return Subset(dset, indices) diff --git a/deforum-stable-diffusion/src/ldm/data/lsun.py b/deforum-stable-diffusion/src/ldm/data/lsun.py new file mode 100644 index 0000000000000000000000000000000000000000..6256e45715ff0b57c53f985594d27cbbbff0e68e --- /dev/null +++ b/deforum-stable-diffusion/src/ldm/data/lsun.py @@ -0,0 +1,92 @@ +import os +import numpy as np +import PIL +from PIL import Image +from torch.utils.data import Dataset +from torchvision import transforms + + +class LSUNBase(Dataset): + def __init__(self, + txt_file, + data_root, + size=None, + interpolation="bicubic", + flip_p=0.5 + ): + self.data_paths = txt_file + self.data_root = data_root + with open(self.data_paths, "r") as f: + self.image_paths = f.read().splitlines() + self._length = len(self.image_paths) + self.labels = { + "relative_file_path_": [l for l in self.image_paths], + "file_path_": [os.path.join(self.data_root, l) + for l in self.image_paths], + } + + self.size = size + self.interpolation = {"linear": PIL.Image.LINEAR, + "bilinear": PIL.Image.BILINEAR, + "bicubic": PIL.Image.BICUBIC, + "lanczos": PIL.Image.LANCZOS, + }[interpolation] + self.flip = transforms.RandomHorizontalFlip(p=flip_p) + + def __len__(self): + return self._length + + def __getitem__(self, i): + example = dict((k, self.labels[k][i]) for k in self.labels) + image = Image.open(example["file_path_"]) + if not image.mode == "RGB": + image = image.convert("RGB") + + # default to score-sde preprocessing + img = np.array(image).astype(np.uint8) + crop = min(img.shape[0], img.shape[1]) + h, w, = img.shape[0], img.shape[1] + img = img[(h - crop) // 2:(h + crop) // 2, + (w - crop) // 2:(w + crop) // 2] + + image = Image.fromarray(img) + if self.size is not None: + image = image.resize((self.size, self.size), resample=self.interpolation) + + image = self.flip(image) + image = np.array(image).astype(np.uint8) + example["image"] = (image / 127.5 - 1.0).astype(np.float32) + return example + + +class LSUNChurchesTrain(LSUNBase): + def __init__(self, **kwargs): + super().__init__(txt_file="data/lsun/church_outdoor_train.txt", data_root="data/lsun/churches", **kwargs) + + +class LSUNChurchesValidation(LSUNBase): + def __init__(self, flip_p=0., **kwargs): + super().__init__(txt_file="data/lsun/church_outdoor_val.txt", data_root="data/lsun/churches", + flip_p=flip_p, **kwargs) + + +class LSUNBedroomsTrain(LSUNBase): + def __init__(self, **kwargs): + super().__init__(txt_file="data/lsun/bedrooms_train.txt", data_root="data/lsun/bedrooms", **kwargs) + + +class LSUNBedroomsValidation(LSUNBase): + def __init__(self, flip_p=0.0, **kwargs): + super().__init__(txt_file="data/lsun/bedrooms_val.txt", data_root="data/lsun/bedrooms", + flip_p=flip_p, **kwargs) + + +class LSUNCatsTrain(LSUNBase): + def __init__(self, **kwargs): + super().__init__(txt_file="data/lsun/cat_train.txt", data_root="data/lsun/cats", **kwargs) + + +class LSUNCatsValidation(LSUNBase): + def __init__(self, flip_p=0., **kwargs): + super().__init__(txt_file="data/lsun/cat_val.txt", data_root="data/lsun/cats", + flip_p=flip_p, **kwargs) diff --git a/deforum-stable-diffusion/src/ldm/data/personalized.py b/deforum-stable-diffusion/src/ldm/data/personalized.py new file mode 100644 index 0000000000000000000000000000000000000000..3c147e8e787b87ad9675b61a021db61d3e07e789 --- /dev/null +++ b/deforum-stable-diffusion/src/ldm/data/personalized.py @@ -0,0 +1,173 @@ +import os +import numpy as np +import PIL +from PIL import Image +from torch.utils.data import Dataset +from torchvision import transforms + +import random + +imagenet_templates_smallest = [ + 'a photo of a {}', +] + +imagenet_templates_small = [ + 'a photo of a {}', + 'a rendering of a {}', + 'a cropped photo of the {}', + 'the photo of a {}', + 'a photo of a clean {}', + 'a photo of a dirty {}', + 'a dark photo of the {}', + 'a photo of my {}', + 'a photo of the cool {}', + 'a close-up photo of a {}', + 'a bright photo of the {}', + 'a cropped photo of a {}', + 'a photo of the {}', + 'a good photo of the {}', + 'a photo of one {}', + 'a close-up photo of the {}', + 'a rendition of the {}', + 'a photo of the clean {}', + 'a rendition of a {}', + 'a photo of a nice {}', + 'a good photo of a {}', + 'a photo of the nice {}', + 'a photo of the small {}', + 'a photo of the weird {}', + 'a photo of the large {}', + 'a photo of a cool {}', + 'a photo of a small {}', +] + +imagenet_dual_templates_small = [ + 'a photo of a {} with {}', + 'a rendering of a {} with {}', + 'a cropped photo of the {} with {}', + 'the photo of a {} with {}', + 'a photo of a clean {} with {}', + 'a photo of a dirty {} with {}', + 'a dark photo of the {} with {}', + 'a photo of my {} with {}', + 'a photo of the cool {} with {}', + 'a close-up photo of a {} with {}', + 'a bright photo of the {} with {}', + 'a cropped photo of a {} with {}', + 'a photo of the {} with {}', + 'a good photo of the {} with {}', + 'a photo of one {} with {}', + 'a close-up photo of the {} with {}', + 'a rendition of the {} with {}', + 'a photo of the clean {} with {}', + 'a rendition of a {} with {}', + 'a photo of a nice {} with {}', + 'a good photo of a {} with {}', + 'a photo of the nice {} with {}', + 'a photo of the small {} with {}', + 'a photo of the weird {} with {}', + 'a photo of the large {} with {}', + 'a photo of a cool {} with {}', + 'a photo of a small {} with {}', +] + +per_img_token_list = [ + 'א', 'ב', 'ג', 'ד', 'ה', 'ו', 'ז', 'ח', 'ט', 'י', 'כ', 'ל', 'מ', 'נ', 'ס', 'ע', 'פ', 'צ', 'ק', 'ר', 'ש', 'ת', +] + +class PersonalizedBase(Dataset): + def __init__(self, + data_root, + size=None, + repeats=100, + interpolation="bicubic", + flip_p=0.5, + set="train", + placeholder_token="*", + per_image_tokens=False, + center_crop=False, + mixing_prob=0.25, + coarse_class_text=None, + ): + + self.data_root = data_root + + self.image_paths = [os.path.join(self.data_root, file_path) for file_path in os.listdir(self.data_root)] + + # self._length = len(self.image_paths) + self.num_images = len(self.image_paths) + self._length = self.num_images + + self.placeholder_token = placeholder_token + + self.per_image_tokens = per_image_tokens + self.center_crop = center_crop + self.mixing_prob = mixing_prob + + self.coarse_class_text = coarse_class_text + + if per_image_tokens: + assert self.num_images < len(per_img_token_list), f"Can't use per-image tokens when the training set contains more than {len(per_img_token_list)} tokens. To enable larger sets, add more tokens to 'per_img_token_list'." + + if set == "train": + self._length = self.num_images * repeats + + self.size = size + self.interpolation = {"linear": PIL.Image.LINEAR, + "bilinear": PIL.Image.BILINEAR, + "bicubic": PIL.Image.BICUBIC, + "lanczos": PIL.Image.LANCZOS, + }[interpolation] + self.flip = transforms.RandomHorizontalFlip(p=flip_p) + + def __len__(self): + return self._length + + def __getitem__(self, i): + example = {} + image = Image.open(self.image_paths[i % self.num_images]) + + placeholder_string = self.placeholder_token + if self.coarse_class_text: + placeholder_string = f"{self.coarse_class_text} {placeholder_string}" + + image = image.convert('RGBA') + new_image = Image.new('RGBA', image.size, 'WHITE') + new_image.paste(image, (0, 0), image) + image = new_image.convert('RGB') + + templates = [ + 'a {} portrait of {}', + 'an {} image of {}', + 'a {} pretty picture of {}', + 'a {} clip art picture of {}', + 'an {} illustration of {}', + 'a {} 3D render of {}', + 'a {} {}', + ] + + filename = os.path.basename(self.image_paths[i % self.num_images]) + filename_tokens = os.path.splitext(filename)[0].replace(' ', '-').replace('_', '-').split('-') + filename_tokens = [token for token in filename_tokens if token.isalpha()] + + text = random.choice(templates).format(' '.join(filename_tokens), self.placeholder_token) + + example["caption"] = text + + # default to score-sde preprocessing + img = np.array(image).astype(np.uint8) + + if self.center_crop: + crop = min(img.shape[0], img.shape[1]) + h, w, = img.shape[0], img.shape[1] + img = img[(h - crop) // 2:(h + crop) // 2, + (w - crop) // 2:(w + crop) // 2] + + image = Image.fromarray(img) + if self.size is not None: + image = image.resize((self.size, self.size), resample=self.interpolation) + + image = self.flip(image) + image = np.array(image).astype(np.uint8) + example["image"] = (image / 127.5 - 1.0).astype(np.float32) + return example \ No newline at end of file diff --git a/deforum-stable-diffusion/src/ldm/data/personalized_style.py b/deforum-stable-diffusion/src/ldm/data/personalized_style.py new file mode 100644 index 0000000000000000000000000000000000000000..1fefb6dd34dfc488931794716b8474b199f23dc5 --- /dev/null +++ b/deforum-stable-diffusion/src/ldm/data/personalized_style.py @@ -0,0 +1,143 @@ +import os +import numpy as np +import PIL +from PIL import Image +from torch.utils.data import Dataset +from torchvision import transforms + +import random + +imagenet_templates_small = [ + 'a painting in the style of {}', + 'a rendering in the style of {}', + 'a cropped painting in the style of {}', + 'the painting in the style of {}', + 'a clean painting in the style of {}', + 'a dirty painting in the style of {}', + 'a dark painting in the style of {}', + 'a picture in the style of {}', + 'a cool painting in the style of {}', + 'a close-up painting in the style of {}', + 'a bright painting in the style of {}', + 'a cropped painting in the style of {}', + 'a good painting in the style of {}', + 'a close-up painting in the style of {}', + 'a rendition in the style of {}', + 'a nice painting in the style of {}', + 'a small painting in the style of {}', + 'a weird painting in the style of {}', + 'a large painting in the style of {}', +] + +imagenet_dual_templates_small = [ + 'a painting in the style of {} with {}', + 'a rendering in the style of {} with {}', + 'a cropped painting in the style of {} with {}', + 'the painting in the style of {} with {}', + 'a clean painting in the style of {} with {}', + 'a dirty painting in the style of {} with {}', + 'a dark painting in the style of {} with {}', + 'a cool painting in the style of {} with {}', + 'a close-up painting in the style of {} with {}', + 'a bright painting in the style of {} with {}', + 'a cropped painting in the style of {} with {}', + 'a good painting in the style of {} with {}', + 'a painting of one {} in the style of {}', + 'a nice painting in the style of {} with {}', + 'a small painting in the style of {} with {}', + 'a weird painting in the style of {} with {}', + 'a large painting in the style of {} with {}', +] + +per_img_token_list = [ + 'א', 'ב', 'ג', 'ד', 'ה', 'ו', 'ז', 'ח', 'ט', 'י', 'כ', 'ל', 'מ', 'נ', 'ס', 'ע', 'פ', 'צ', 'ק', 'ר', 'ש', 'ת', +] + +class PersonalizedBase(Dataset): + def __init__(self, + data_root, + size=None, + repeats=100, + interpolation="bicubic", + flip_p=0.5, + set="train", + placeholder_token="*", + per_image_tokens=False, + center_crop=False, + ): + + self.data_root = data_root + + self.image_paths = [os.path.join(self.data_root, file_path) for file_path in os.listdir(self.data_root)] + + # self._length = len(self.image_paths) + self.num_images = len(self.image_paths) + self._length = self.num_images + + self.placeholder_token = placeholder_token + + self.per_image_tokens = per_image_tokens + self.center_crop = center_crop + + if per_image_tokens: + assert self.num_images < len(per_img_token_list), f"Can't use per-image tokens when the training set contains more than {len(per_img_token_list)} tokens. To enable larger sets, add more tokens to 'per_img_token_list'." + + if set == "train": + self._length = self.num_images * repeats + + self.size = size + self.interpolation = {"linear": PIL.Image.LINEAR, + "bilinear": PIL.Image.BILINEAR, + "bicubic": PIL.Image.BICUBIC, + "lanczos": PIL.Image.LANCZOS, + }[interpolation] + self.flip = transforms.RandomHorizontalFlip(p=flip_p) + + def __len__(self): + return self._length + + def __getitem__(self, i): + example = {} + image = Image.open(self.image_paths[i % self.num_images]) + + image = image.convert('RGBA') + new_image = Image.new('RGBA', image.size, 'WHITE') + new_image.paste(image, (0, 0), image) + image = new_image.convert('RGB') + + templates = [ + 'a {} portrait of {}', + 'an {} image of {}', + 'a {} pretty picture of {}', + 'a {} clip art picture of {}', + 'an {} illustration of {}', + 'a {} 3D render of {}', + 'a {} {}', + ] + + filename = os.path.basename(self.image_paths[i % self.num_images]) + filename_tokens = os.path.splitext(filename)[0].replace('_', '-').split('-') + filename_tokens = [token for token in filename_tokens if token.isalpha()] + + text = random.choice(templates).format(' '.join(filename_tokens), self.placeholder_token) + print(text) + + example["caption"] = text + + # default to score-sde preprocessing + img = np.array(image).astype(np.uint8) + + if self.center_crop: + crop = min(img.shape[0], img.shape[1]) + h, w, = img.shape[0], img.shape[1] + img = img[(h - crop) // 2:(h + crop) // 2, + (w - crop) // 2:(w + crop) // 2] + + image = Image.fromarray(img) + if self.size is not None: + image = image.resize((self.size, self.size), resample=self.interpolation) + + image = self.flip(image) + image = np.array(image).astype(np.uint8) + example["image"] = (image / 127.5 - 1.0).astype(np.float32) + return example \ No newline at end of file diff --git a/deforum-stable-diffusion/src/ldm/lr_scheduler.py b/deforum-stable-diffusion/src/ldm/lr_scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..be39da9ca6dacc22bf3df9c7389bbb403a4a3ade --- /dev/null +++ b/deforum-stable-diffusion/src/ldm/lr_scheduler.py @@ -0,0 +1,98 @@ +import numpy as np + + +class LambdaWarmUpCosineScheduler: + """ + note: use with a base_lr of 1.0 + """ + def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0): + self.lr_warm_up_steps = warm_up_steps + self.lr_start = lr_start + self.lr_min = lr_min + self.lr_max = lr_max + self.lr_max_decay_steps = max_decay_steps + self.last_lr = 0. + self.verbosity_interval = verbosity_interval + + def schedule(self, n, **kwargs): + if self.verbosity_interval > 0: + if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}") + if n < self.lr_warm_up_steps: + lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start + self.last_lr = lr + return lr + else: + t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps) + t = min(t, 1.0) + lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * ( + 1 + np.cos(t * np.pi)) + self.last_lr = lr + return lr + + def __call__(self, n, **kwargs): + return self.schedule(n,**kwargs) + + +class LambdaWarmUpCosineScheduler2: + """ + supports repeated iterations, configurable via lists + note: use with a base_lr of 1.0. + """ + def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0): + assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths) + self.lr_warm_up_steps = warm_up_steps + self.f_start = f_start + self.f_min = f_min + self.f_max = f_max + self.cycle_lengths = cycle_lengths + self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths)) + self.last_f = 0. + self.verbosity_interval = verbosity_interval + + def find_in_interval(self, n): + interval = 0 + for cl in self.cum_cycles[1:]: + if n <= cl: + return interval + interval += 1 + + def schedule(self, n, **kwargs): + cycle = self.find_in_interval(n) + n = n - self.cum_cycles[cycle] + if self.verbosity_interval > 0: + if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " + f"current cycle {cycle}") + if n < self.lr_warm_up_steps[cycle]: + f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] + self.last_f = f + return f + else: + t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]) + t = min(t, 1.0) + f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * ( + 1 + np.cos(t * np.pi)) + self.last_f = f + return f + + def __call__(self, n, **kwargs): + return self.schedule(n, **kwargs) + + +class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2): + + def schedule(self, n, **kwargs): + cycle = self.find_in_interval(n) + n = n - self.cum_cycles[cycle] + if self.verbosity_interval > 0: + if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " + f"current cycle {cycle}") + + if n < self.lr_warm_up_steps[cycle]: + f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] + self.last_f = f + return f + else: + f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle]) + self.last_f = f + return f + diff --git a/deforum-stable-diffusion/src/ldm/models/autoencoder.py b/deforum-stable-diffusion/src/ldm/models/autoencoder.py new file mode 100644 index 0000000000000000000000000000000000000000..6a9c4f45498561953b8085981609b2a3298a5473 --- /dev/null +++ b/deforum-stable-diffusion/src/ldm/models/autoencoder.py @@ -0,0 +1,443 @@ +import torch +import pytorch_lightning as pl +import torch.nn.functional as F +from contextlib import contextmanager + +from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer + +from ldm.modules.diffusionmodules.model import Encoder, Decoder +from ldm.modules.distributions.distributions import DiagonalGaussianDistribution + +from ldm.util import instantiate_from_config + + +class VQModel(pl.LightningModule): + def __init__(self, + ddconfig, + lossconfig, + n_embed, + embed_dim, + ckpt_path=None, + ignore_keys=[], + image_key="image", + colorize_nlabels=None, + monitor=None, + batch_resize_range=None, + scheduler_config=None, + lr_g_factor=1.0, + remap=None, + sane_index_shape=False, # tell vector quantizer to return indices as bhw + use_ema=False + ): + super().__init__() + self.embed_dim = embed_dim + self.n_embed = n_embed + self.image_key = image_key + self.encoder = Encoder(**ddconfig) + self.decoder = Decoder(**ddconfig) + self.loss = instantiate_from_config(lossconfig) + self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25, + remap=remap, + sane_index_shape=sane_index_shape) + self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1) + self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) + if colorize_nlabels is not None: + assert type(colorize_nlabels)==int + self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) + if monitor is not None: + self.monitor = monitor + self.batch_resize_range = batch_resize_range + if self.batch_resize_range is not None: + print(f"{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.") + + self.use_ema = use_ema + if self.use_ema: + self.model_ema = LitEma(self) + print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") + + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) + self.scheduler_config = scheduler_config + self.lr_g_factor = lr_g_factor + + @contextmanager + def ema_scope(self, context=None): + if self.use_ema: + self.model_ema.store(self.parameters()) + self.model_ema.copy_to(self) + if context is not None: + print(f"{context}: Switched to EMA weights") + try: + yield None + finally: + if self.use_ema: + self.model_ema.restore(self.parameters()) + if context is not None: + print(f"{context}: Restored training weights") + + def init_from_ckpt(self, path, ignore_keys=list()): + sd = torch.load(path, map_location="cpu")["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + missing, unexpected = self.load_state_dict(sd, strict=False) + print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") + if len(missing) > 0: + print(f"Missing Keys: {missing}") + print(f"Unexpected Keys: {unexpected}") + + def on_train_batch_end(self, *args, **kwargs): + if self.use_ema: + self.model_ema(self) + + def encode(self, x): + h = self.encoder(x) + h = self.quant_conv(h) + quant, emb_loss, info = self.quantize(h) + return quant, emb_loss, info + + def encode_to_prequant(self, x): + h = self.encoder(x) + h = self.quant_conv(h) + return h + + def decode(self, quant): + quant = self.post_quant_conv(quant) + dec = self.decoder(quant) + return dec + + def decode_code(self, code_b): + quant_b = self.quantize.embed_code(code_b) + dec = self.decode(quant_b) + return dec + + def forward(self, input, return_pred_indices=False): + quant, diff, (_,_,ind) = self.encode(input) + dec = self.decode(quant) + if return_pred_indices: + return dec, diff, ind + return dec, diff + + def get_input(self, batch, k): + x = batch[k] + if len(x.shape) == 3: + x = x[..., None] + x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float() + if self.batch_resize_range is not None: + lower_size = self.batch_resize_range[0] + upper_size = self.batch_resize_range[1] + if self.global_step <= 4: + # do the first few batches with max size to avoid later oom + new_resize = upper_size + else: + new_resize = np.random.choice(np.arange(lower_size, upper_size+16, 16)) + if new_resize != x.shape[2]: + x = F.interpolate(x, size=new_resize, mode="bicubic") + x = x.detach() + return x + + def training_step(self, batch, batch_idx, optimizer_idx): + # https://github.com/pytorch/pytorch/issues/37142 + # try not to fool the heuristics + x = self.get_input(batch, self.image_key) + xrec, qloss, ind = self(x, return_pred_indices=True) + + if optimizer_idx == 0: + # autoencode + aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step, + last_layer=self.get_last_layer(), split="train", + predicted_indices=ind) + + self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True) + return aeloss + + if optimizer_idx == 1: + # discriminator + discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step, + last_layer=self.get_last_layer(), split="train") + self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True) + return discloss + + def validation_step(self, batch, batch_idx): + log_dict = self._validation_step(batch, batch_idx) + with self.ema_scope(): + log_dict_ema = self._validation_step(batch, batch_idx, suffix="_ema") + return log_dict + + def _validation_step(self, batch, batch_idx, suffix=""): + x = self.get_input(batch, self.image_key) + xrec, qloss, ind = self(x, return_pred_indices=True) + aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0, + self.global_step, + last_layer=self.get_last_layer(), + split="val"+suffix, + predicted_indices=ind + ) + + discloss, log_dict_disc = self.loss(qloss, x, xrec, 1, + self.global_step, + last_layer=self.get_last_layer(), + split="val"+suffix, + predicted_indices=ind + ) + rec_loss = log_dict_ae[f"val{suffix}/rec_loss"] + self.log(f"val{suffix}/rec_loss", rec_loss, + prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True) + self.log(f"val{suffix}/aeloss", aeloss, + prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True) + if version.parse(pl.__version__) >= version.parse('1.4.0'): + del log_dict_ae[f"val{suffix}/rec_loss"] + self.log_dict(log_dict_ae) + self.log_dict(log_dict_disc) + return self.log_dict + + def configure_optimizers(self): + lr_d = self.learning_rate + lr_g = self.lr_g_factor*self.learning_rate + print("lr_d", lr_d) + print("lr_g", lr_g) + opt_ae = torch.optim.Adam(list(self.encoder.parameters())+ + list(self.decoder.parameters())+ + list(self.quantize.parameters())+ + list(self.quant_conv.parameters())+ + list(self.post_quant_conv.parameters()), + lr=lr_g, betas=(0.5, 0.9)) + opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), + lr=lr_d, betas=(0.5, 0.9)) + + if self.scheduler_config is not None: + scheduler = instantiate_from_config(self.scheduler_config) + + print("Setting up LambdaLR scheduler...") + scheduler = [ + { + 'scheduler': LambdaLR(opt_ae, lr_lambda=scheduler.schedule), + 'interval': 'step', + 'frequency': 1 + }, + { + 'scheduler': LambdaLR(opt_disc, lr_lambda=scheduler.schedule), + 'interval': 'step', + 'frequency': 1 + }, + ] + return [opt_ae, opt_disc], scheduler + return [opt_ae, opt_disc], [] + + def get_last_layer(self): + return self.decoder.conv_out.weight + + def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs): + log = dict() + x = self.get_input(batch, self.image_key) + x = x.to(self.device) + if only_inputs: + log["inputs"] = x + return log + xrec, _ = self(x) + if x.shape[1] > 3: + # colorize with random projection + assert xrec.shape[1] > 3 + x = self.to_rgb(x) + xrec = self.to_rgb(xrec) + log["inputs"] = x + log["reconstructions"] = xrec + if plot_ema: + with self.ema_scope(): + xrec_ema, _ = self(x) + if x.shape[1] > 3: xrec_ema = self.to_rgb(xrec_ema) + log["reconstructions_ema"] = xrec_ema + return log + + def to_rgb(self, x): + assert self.image_key == "segmentation" + if not hasattr(self, "colorize"): + self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) + x = F.conv2d(x, weight=self.colorize) + x = 2.*(x-x.min())/(x.max()-x.min()) - 1. + return x + + +class VQModelInterface(VQModel): + def __init__(self, embed_dim, *args, **kwargs): + super().__init__(embed_dim=embed_dim, *args, **kwargs) + self.embed_dim = embed_dim + + def encode(self, x): + h = self.encoder(x) + h = self.quant_conv(h) + return h + + def decode(self, h, force_not_quantize=False): + # also go through quantization layer + if not force_not_quantize: + quant, emb_loss, info = self.quantize(h) + else: + quant = h + quant = self.post_quant_conv(quant) + dec = self.decoder(quant) + return dec + + +class AutoencoderKL(pl.LightningModule): + def __init__(self, + ddconfig, + lossconfig, + embed_dim, + ckpt_path=None, + ignore_keys=[], + image_key="image", + colorize_nlabels=None, + monitor=None, + ): + super().__init__() + self.image_key = image_key + self.encoder = Encoder(**ddconfig) + self.decoder = Decoder(**ddconfig) + self.loss = instantiate_from_config(lossconfig) + assert ddconfig["double_z"] + self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1) + self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) + self.embed_dim = embed_dim + if colorize_nlabels is not None: + assert type(colorize_nlabels)==int + self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) + if monitor is not None: + self.monitor = monitor + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) + + def init_from_ckpt(self, path, ignore_keys=list()): + sd = torch.load(path, map_location="cpu")["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + self.load_state_dict(sd, strict=False) + print(f"Restored from {path}") + + def encode(self, x): + h = self.encoder(x) + moments = self.quant_conv(h) + posterior = DiagonalGaussianDistribution(moments) + return posterior + + def decode(self, z): + z = self.post_quant_conv(z) + dec = self.decoder(z) + return dec + + def forward(self, input, sample_posterior=True): + posterior = self.encode(input) + if sample_posterior: + z = posterior.sample() + else: + z = posterior.mode() + dec = self.decode(z) + return dec, posterior + + def get_input(self, batch, k): + x = batch[k] + if len(x.shape) == 3: + x = x[..., None] + x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float() + return x + + def training_step(self, batch, batch_idx, optimizer_idx): + inputs = self.get_input(batch, self.image_key) + reconstructions, posterior = self(inputs) + + if optimizer_idx == 0: + # train encoder+decoder+logvar + aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, + last_layer=self.get_last_layer(), split="train") + self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) + self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False) + return aeloss + + if optimizer_idx == 1: + # train the discriminator + discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, + last_layer=self.get_last_layer(), split="train") + + self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) + self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False) + return discloss + + def validation_step(self, batch, batch_idx): + inputs = self.get_input(batch, self.image_key) + reconstructions, posterior = self(inputs) + aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step, + last_layer=self.get_last_layer(), split="val") + + discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step, + last_layer=self.get_last_layer(), split="val") + + self.log("val/rec_loss", log_dict_ae["val/rec_loss"]) + self.log_dict(log_dict_ae) + self.log_dict(log_dict_disc) + return self.log_dict + + def configure_optimizers(self): + lr = self.learning_rate + opt_ae = torch.optim.Adam(list(self.encoder.parameters())+ + list(self.decoder.parameters())+ + list(self.quant_conv.parameters())+ + list(self.post_quant_conv.parameters()), + lr=lr, betas=(0.5, 0.9)) + opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), + lr=lr, betas=(0.5, 0.9)) + return [opt_ae, opt_disc], [] + + def get_last_layer(self): + return self.decoder.conv_out.weight + + @torch.no_grad() + def log_images(self, batch, only_inputs=False, **kwargs): + log = dict() + x = self.get_input(batch, self.image_key) + x = x.to(self.device) + if not only_inputs: + xrec, posterior = self(x) + if x.shape[1] > 3: + # colorize with random projection + assert xrec.shape[1] > 3 + x = self.to_rgb(x) + xrec = self.to_rgb(xrec) + log["samples"] = self.decode(torch.randn_like(posterior.sample())) + log["reconstructions"] = xrec + log["inputs"] = x + return log + + def to_rgb(self, x): + assert self.image_key == "segmentation" + if not hasattr(self, "colorize"): + self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) + x = F.conv2d(x, weight=self.colorize) + x = 2.*(x-x.min())/(x.max()-x.min()) - 1. + return x + + +class IdentityFirstStage(torch.nn.Module): + def __init__(self, *args, vq_interface=False, **kwargs): + self.vq_interface = vq_interface # TODO: Should be true by default but check to not break older stuff + super().__init__() + + def encode(self, x, *args, **kwargs): + return x + + def decode(self, x, *args, **kwargs): + return x + + def quantize(self, x, *args, **kwargs): + if self.vq_interface: + return x, None, [None, None, None] + return x + + def forward(self, x, *args, **kwargs): + return x diff --git a/deforum-stable-diffusion/src/ldm/models/diffusion/__init__.py b/deforum-stable-diffusion/src/ldm/models/diffusion/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/deforum-stable-diffusion/src/ldm/models/diffusion/__pycache__/__init__.cpython-39.pyc b/deforum-stable-diffusion/src/ldm/models/diffusion/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8f0ddb9fd20efb1d0d9b45a78c0f84931e1125aa Binary files /dev/null and b/deforum-stable-diffusion/src/ldm/models/diffusion/__pycache__/__init__.cpython-39.pyc differ diff --git a/deforum-stable-diffusion/src/ldm/models/diffusion/__pycache__/ddim.cpython-39.pyc b/deforum-stable-diffusion/src/ldm/models/diffusion/__pycache__/ddim.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0d31bc70486bb89871f405892ff7c16c6b8e913c Binary files /dev/null and b/deforum-stable-diffusion/src/ldm/models/diffusion/__pycache__/ddim.cpython-39.pyc differ diff --git a/deforum-stable-diffusion/src/ldm/models/diffusion/__pycache__/plms.cpython-39.pyc b/deforum-stable-diffusion/src/ldm/models/diffusion/__pycache__/plms.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..be08dc664c1e19820b775e7629c130d2a24ea8eb Binary files /dev/null and b/deforum-stable-diffusion/src/ldm/models/diffusion/__pycache__/plms.cpython-39.pyc differ diff --git a/deforum-stable-diffusion/src/ldm/models/diffusion/classifier.py b/deforum-stable-diffusion/src/ldm/models/diffusion/classifier.py new file mode 100644 index 0000000000000000000000000000000000000000..67e98b9d8ffb96a150b517497ace0a242d7163ef --- /dev/null +++ b/deforum-stable-diffusion/src/ldm/models/diffusion/classifier.py @@ -0,0 +1,267 @@ +import os +import torch +import pytorch_lightning as pl +from omegaconf import OmegaConf +from torch.nn import functional as F +from torch.optim import AdamW +from torch.optim.lr_scheduler import LambdaLR +from copy import deepcopy +from einops import rearrange +from glob import glob +from natsort import natsorted + +from ldm.modules.diffusionmodules.openaimodel import EncoderUNetModel, UNetModel +from ldm.util import log_txt_as_img, default, ismap, instantiate_from_config + +__models__ = { + 'class_label': EncoderUNetModel, + 'segmentation': UNetModel +} + + +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + + +class NoisyLatentImageClassifier(pl.LightningModule): + + def __init__(self, + diffusion_path, + num_classes, + ckpt_path=None, + pool='attention', + label_key=None, + diffusion_ckpt_path=None, + scheduler_config=None, + weight_decay=1.e-2, + log_steps=10, + monitor='val/loss', + *args, + **kwargs): + super().__init__(*args, **kwargs) + self.num_classes = num_classes + # get latest config of diffusion model + diffusion_config = natsorted(glob(os.path.join(diffusion_path, 'configs', '*-project.yaml')))[-1] + self.diffusion_config = OmegaConf.load(diffusion_config).model + self.diffusion_config.params.ckpt_path = diffusion_ckpt_path + self.load_diffusion() + + self.monitor = monitor + self.numd = self.diffusion_model.first_stage_model.encoder.num_resolutions - 1 + self.log_time_interval = self.diffusion_model.num_timesteps // log_steps + self.log_steps = log_steps + + self.label_key = label_key if not hasattr(self.diffusion_model, 'cond_stage_key') \ + else self.diffusion_model.cond_stage_key + + assert self.label_key is not None, 'label_key neither in diffusion model nor in model.params' + + if self.label_key not in __models__: + raise NotImplementedError() + + self.load_classifier(ckpt_path, pool) + + self.scheduler_config = scheduler_config + self.use_scheduler = self.scheduler_config is not None + self.weight_decay = weight_decay + + def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): + sd = torch.load(path, map_location="cpu") + if "state_dict" in list(sd.keys()): + sd = sd["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict( + sd, strict=False) + print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") + if len(missing) > 0: + print(f"Missing Keys: {missing}") + if len(unexpected) > 0: + print(f"Unexpected Keys: {unexpected}") + + def load_diffusion(self): + model = instantiate_from_config(self.diffusion_config) + self.diffusion_model = model.eval() + self.diffusion_model.train = disabled_train + for param in self.diffusion_model.parameters(): + param.requires_grad = False + + def load_classifier(self, ckpt_path, pool): + model_config = deepcopy(self.diffusion_config.params.unet_config.params) + model_config.in_channels = self.diffusion_config.params.unet_config.params.out_channels + model_config.out_channels = self.num_classes + if self.label_key == 'class_label': + model_config.pool = pool + + self.model = __models__[self.label_key](**model_config) + if ckpt_path is not None: + print('#####################################################################') + print(f'load from ckpt "{ckpt_path}"') + print('#####################################################################') + self.init_from_ckpt(ckpt_path) + + @torch.no_grad() + def get_x_noisy(self, x, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x)) + continuous_sqrt_alpha_cumprod = None + if self.diffusion_model.use_continuous_noise: + continuous_sqrt_alpha_cumprod = self.diffusion_model.sample_continuous_noise_level(x.shape[0], t + 1) + # todo: make sure t+1 is correct here + + return self.diffusion_model.q_sample(x_start=x, t=t, noise=noise, + continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod) + + def forward(self, x_noisy, t, *args, **kwargs): + return self.model(x_noisy, t) + + @torch.no_grad() + def get_input(self, batch, k): + x = batch[k] + if len(x.shape) == 3: + x = x[..., None] + x = rearrange(x, 'b h w c -> b c h w') + x = x.to(memory_format=torch.contiguous_format).float() + return x + + @torch.no_grad() + def get_conditioning(self, batch, k=None): + if k is None: + k = self.label_key + assert k is not None, 'Needs to provide label key' + + targets = batch[k].to(self.device) + + if self.label_key == 'segmentation': + targets = rearrange(targets, 'b h w c -> b c h w') + for down in range(self.numd): + h, w = targets.shape[-2:] + targets = F.interpolate(targets, size=(h // 2, w // 2), mode='nearest') + + # targets = rearrange(targets,'b c h w -> b h w c') + + return targets + + def compute_top_k(self, logits, labels, k, reduction="mean"): + _, top_ks = torch.topk(logits, k, dim=1) + if reduction == "mean": + return (top_ks == labels[:, None]).float().sum(dim=-1).mean().item() + elif reduction == "none": + return (top_ks == labels[:, None]).float().sum(dim=-1) + + def on_train_epoch_start(self): + # save some memory + self.diffusion_model.model.to('cpu') + + @torch.no_grad() + def write_logs(self, loss, logits, targets): + log_prefix = 'train' if self.training else 'val' + log = {} + log[f"{log_prefix}/loss"] = loss.mean() + log[f"{log_prefix}/acc@1"] = self.compute_top_k( + logits, targets, k=1, reduction="mean" + ) + log[f"{log_prefix}/acc@5"] = self.compute_top_k( + logits, targets, k=5, reduction="mean" + ) + + self.log_dict(log, prog_bar=False, logger=True, on_step=self.training, on_epoch=True) + self.log('loss', log[f"{log_prefix}/loss"], prog_bar=True, logger=False) + self.log('global_step', self.global_step, logger=False, on_epoch=False, prog_bar=True) + lr = self.optimizers().param_groups[0]['lr'] + self.log('lr_abs', lr, on_step=True, logger=True, on_epoch=False, prog_bar=True) + + def shared_step(self, batch, t=None): + x, *_ = self.diffusion_model.get_input(batch, k=self.diffusion_model.first_stage_key) + targets = self.get_conditioning(batch) + if targets.dim() == 4: + targets = targets.argmax(dim=1) + if t is None: + t = torch.randint(0, self.diffusion_model.num_timesteps, (x.shape[0],), device=self.device).long() + else: + t = torch.full(size=(x.shape[0],), fill_value=t, device=self.device).long() + x_noisy = self.get_x_noisy(x, t) + logits = self(x_noisy, t) + + loss = F.cross_entropy(logits, targets, reduction='none') + + self.write_logs(loss.detach(), logits.detach(), targets.detach()) + + loss = loss.mean() + return loss, logits, x_noisy, targets + + def training_step(self, batch, batch_idx): + loss, *_ = self.shared_step(batch) + return loss + + def reset_noise_accs(self): + self.noisy_acc = {t: {'acc@1': [], 'acc@5': []} for t in + range(0, self.diffusion_model.num_timesteps, self.diffusion_model.log_every_t)} + + def on_validation_start(self): + self.reset_noise_accs() + + @torch.no_grad() + def validation_step(self, batch, batch_idx): + loss, *_ = self.shared_step(batch) + + for t in self.noisy_acc: + _, logits, _, targets = self.shared_step(batch, t) + self.noisy_acc[t]['acc@1'].append(self.compute_top_k(logits, targets, k=1, reduction='mean')) + self.noisy_acc[t]['acc@5'].append(self.compute_top_k(logits, targets, k=5, reduction='mean')) + + return loss + + def configure_optimizers(self): + optimizer = AdamW(self.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay) + + if self.use_scheduler: + scheduler = instantiate_from_config(self.scheduler_config) + + print("Setting up LambdaLR scheduler...") + scheduler = [ + { + 'scheduler': LambdaLR(optimizer, lr_lambda=scheduler.schedule), + 'interval': 'step', + 'frequency': 1 + }] + return [optimizer], scheduler + + return optimizer + + @torch.no_grad() + def log_images(self, batch, N=8, *args, **kwargs): + log = dict() + x = self.get_input(batch, self.diffusion_model.first_stage_key) + log['inputs'] = x + + y = self.get_conditioning(batch) + + if self.label_key == 'class_label': + y = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"]) + log['labels'] = y + + if ismap(y): + log['labels'] = self.diffusion_model.to_rgb(y) + + for step in range(self.log_steps): + current_time = step * self.log_time_interval + + _, logits, x_noisy, _ = self.shared_step(batch, t=current_time) + + log[f'inputs@t{current_time}'] = x_noisy + + pred = F.one_hot(logits.argmax(dim=1), num_classes=self.num_classes) + pred = rearrange(pred, 'b h w c -> b c h w') + + log[f'pred@t{current_time}'] = self.diffusion_model.to_rgb(pred) + + for key in log: + log[key] = log[key][:N] + + return log diff --git a/deforum-stable-diffusion/src/ldm/models/diffusion/ddim.py b/deforum-stable-diffusion/src/ldm/models/diffusion/ddim.py new file mode 100644 index 0000000000000000000000000000000000000000..6cc2afe296f98a9d0f5513a26533595a200fade3 --- /dev/null +++ b/deforum-stable-diffusion/src/ldm/models/diffusion/ddim.py @@ -0,0 +1,246 @@ +"""SAMPLING ONLY.""" + +import torch +import numpy as np +#from tqdm.notebook import tqdm +from tqdm import tqdm +from functools import partial + +from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, \ + extract_into_tensor + + +class DDIMSampler(object): + def __init__(self, model, schedule="linear", **kwargs): + super().__init__() + self.model = model + self.ddpm_num_timesteps = model.num_timesteps + self.schedule = schedule + + def register_buffer(self, name, attr): + if type(attr) == torch.Tensor: + if attr.device != torch.device("cuda"): + attr = attr.to(torch.device("cuda")) + setattr(self, name, attr) + + def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): + self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, + num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose) + alphas_cumprod = self.model.alphas_cumprod + assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' + to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) + + self.register_buffer('betas', to_torch(self.model.betas)) + self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) + self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))) + self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu()))) + self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu()))) + self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) + self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1))) + + # ddim sampling parameters + ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(), + ddim_timesteps=self.ddim_timesteps, + eta=ddim_eta,verbose=verbose) + self.register_buffer('ddim_sigmas', ddim_sigmas) + self.register_buffer('ddim_alphas', ddim_alphas) + self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) + self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas)) + sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( + (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( + 1 - self.alphas_cumprod / self.alphas_cumprod_prev)) + self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) + + @torch.no_grad() + def sample(self, + S, + batch_size, + shape, + conditioning=None, + callback=None, + normals_sequence=None, + img_callback=None, + quantize_x0=False, + eta=0., + mask=None, + x0=None, + temperature=1., + noise_dropout=0., + score_corrector=None, + corrector_kwargs=None, + verbose=True, + x_T=None, + log_every_t=100, + unconditional_guidance_scale=1., + unconditional_conditioning=None, + # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... + **kwargs + ): + if conditioning is not None: + if isinstance(conditioning, dict): + cbs = conditioning[list(conditioning.keys())[0]].shape[0] + if cbs != batch_size: + print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") + else: + if conditioning.shape[0] != batch_size: + print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") + + self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) + # sampling + C, H, W = shape + size = (batch_size, C, H, W) + print(f'Data shape for DDIM sampling is {size}, eta {eta}') + + samples, intermediates = self.ddim_sampling(conditioning, size, + callback=callback, + img_callback=img_callback, + quantize_denoised=quantize_x0, + mask=mask, x0=x0, + ddim_use_original_steps=False, + noise_dropout=noise_dropout, + temperature=temperature, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + x_T=x_T, + log_every_t=log_every_t, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + ) + return samples, intermediates + + @torch.no_grad() + def ddim_sampling(self, cond, shape, + x_T=None, ddim_use_original_steps=False, + callback=None, timesteps=None, quantize_denoised=False, + mask=None, x0=None, img_callback=None, log_every_t=100, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, + unconditional_guidance_scale=1., unconditional_conditioning=None,): + device = self.model.betas.device + b = shape[0] + if x_T is None: + img = torch.randn(shape, device=device) + else: + img = x_T + + if timesteps is None: + timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps + elif timesteps is not None and not ddim_use_original_steps: + subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 + timesteps = self.ddim_timesteps[:subset_end] + + intermediates = {'x_inter': [img], 'pred_x0': [img]} + time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps) + total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] + print(f"Running DDIM Sampling with {total_steps} timesteps") + + iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps) + + for i, step in enumerate(iterator): + index = total_steps - i - 1 + ts = torch.full((b,), step, device=device, dtype=torch.long) + + if mask is not None: + assert x0 is not None + img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? + img = img_orig * mask + (1. - mask) * img + + outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps, + quantize_denoised=quantize_denoised, temperature=temperature, + noise_dropout=noise_dropout, score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning) + img, pred_x0 = outs + if callback: callback(i) + if img_callback: img_callback(img, pred_x0, i) + + if index % log_every_t == 0 or index == total_steps - 1: + intermediates['x_inter'].append(img) + intermediates['pred_x0'].append(pred_x0) + + return img, intermediates + + @torch.no_grad() + def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, + unconditional_guidance_scale=1., unconditional_conditioning=None): + b, *_, device = *x.shape, x.device + + if unconditional_conditioning is None or unconditional_guidance_scale == 1.: + e_t = self.model.apply_model(x, t, c) + else: + x_in = torch.cat([x] * 2) + t_in = torch.cat([t] * 2) + c_in = torch.cat([unconditional_conditioning, c]) + e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) + e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) + + if score_corrector is not None: + assert self.model.parameterization == "eps" + e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) + + alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas + alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev + sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas + sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas + # select parameters corresponding to the currently considered timestep + a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) + a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) + sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) + sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device) + + # current prediction for x_0 + pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() + if quantize_denoised: + pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) + # direction pointing to x_t + dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t + noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature + if noise_dropout > 0.: + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise + return x_prev, pred_x0 + + @torch.no_grad() + def stochastic_encode(self, x0, t, use_original_steps=False, noise=None): + # fast, but does not allow for exact reconstruction + # t serves as an index to gather the correct alphas + if use_original_steps: + sqrt_alphas_cumprod = self.sqrt_alphas_cumprod + sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod + else: + sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas) + sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas + + if noise is None: + noise = torch.randn_like(x0) + return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 + + extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise) + + @torch.no_grad() + def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None, + use_original_steps=False, img_callback=None): + + timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps + timesteps = timesteps[:t_start] + + time_range = np.flip(timesteps) + total_steps = timesteps.shape[0] + print(f"Running DDIM Sampling with {total_steps} timesteps") + + iterator = tqdm(time_range, desc='Decoding image', total=total_steps) + x_dec = x_latent + for i, step in enumerate(iterator): + index = total_steps - i - 1 + + ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long) + x_dec, pred_x0 = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning) + + if img_callback: img_callback(x_dec, pred_x0, i) + + return x_dec diff --git a/deforum-stable-diffusion/src/ldm/models/diffusion/ddpm.py b/deforum-stable-diffusion/src/ldm/models/diffusion/ddpm.py new file mode 100644 index 0000000000000000000000000000000000000000..ef05565c5cd9246bcaf3c3576a3e12ae6b207904 --- /dev/null +++ b/deforum-stable-diffusion/src/ldm/models/diffusion/ddpm.py @@ -0,0 +1,1447 @@ +""" +wild mixture of +https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py +https://github.com/openai/improved-diffusion/blob/e94489283bb876ac1477d5dd7709bbbd2d9902ce/improved_diffusion/gaussian_diffusion.py +https://github.com/CompVis/taming-transformers +-- merci +""" + +import torch +import torch.nn as nn +import numpy as np +import pytorch_lightning as pl +from torch.optim.lr_scheduler import LambdaLR +from einops import rearrange, repeat +from contextlib import contextmanager +from functools import partial +#from tqdm.notebook import tqdm +from tqdm import tqdm +from torchvision.utils import make_grid +from pytorch_lightning.utilities.distributed import rank_zero_only + +from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config +from ldm.modules.ema import LitEma +from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution +from ldm.models.autoencoder import VQModelInterface, IdentityFirstStage, AutoencoderKL +from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like +from ldm.models.diffusion.ddim import DDIMSampler + + +__conditioning_keys__ = {'concat': 'c_concat', + 'crossattn': 'c_crossattn', + 'adm': 'y'} + + +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + + +def uniform_on_device(r1, r2, shape, device): + return (r1 - r2) * torch.rand(*shape, device=device) + r2 + + +class DDPM(pl.LightningModule): + # classic DDPM with Gaussian diffusion, in image space + def __init__(self, + unet_config, + timesteps=1000, + beta_schedule="linear", + loss_type="l2", + ckpt_path=None, + ignore_keys=[], + load_only_unet=False, + monitor="val/loss", + use_ema=True, + first_stage_key="image", + image_size=256, + channels=3, + log_every_t=100, + clip_denoised=True, + linear_start=1e-4, + linear_end=2e-2, + cosine_s=8e-3, + given_betas=None, + original_elbo_weight=0., + v_posterior=0., # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta + l_simple_weight=1., + conditioning_key=None, + parameterization="eps", # all assuming fixed variance schedules + scheduler_config=None, + use_positional_encodings=False, + learn_logvar=False, + logvar_init=0., + ): + super().__init__() + assert parameterization in ["eps", "x0"], 'currently only supporting "eps" and "x0"' + self.parameterization = parameterization + #print(f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode") + self.cond_stage_model = None + self.clip_denoised = clip_denoised + self.log_every_t = log_every_t + self.first_stage_key = first_stage_key + self.image_size = image_size # try conv? + self.channels = channels + self.use_positional_encodings = use_positional_encodings + self.model = DiffusionWrapper(unet_config, conditioning_key) + count_params(self.model, verbose=False) + self.use_ema = use_ema + if self.use_ema: + self.model_ema = LitEma(self.model) + print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") + + self.use_scheduler = scheduler_config is not None + if self.use_scheduler: + self.scheduler_config = scheduler_config + + self.v_posterior = v_posterior + self.original_elbo_weight = original_elbo_weight + self.l_simple_weight = l_simple_weight + + if monitor is not None: + self.monitor = monitor + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet) + + self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps, + linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s) + + self.loss_type = loss_type + + self.learn_logvar = learn_logvar + self.logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,)) + if self.learn_logvar: + self.logvar = nn.Parameter(self.logvar, requires_grad=True) + + + def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000, + linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): + if exists(given_betas): + betas = given_betas + else: + betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, + cosine_s=cosine_s) + alphas = 1. - betas + alphas_cumprod = np.cumprod(alphas, axis=0) + alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) + + timesteps, = betas.shape + self.num_timesteps = int(timesteps) + self.linear_start = linear_start + self.linear_end = linear_end + assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep' + + to_torch = partial(torch.tensor, dtype=torch.float32) + + self.register_buffer('betas', to_torch(betas)) + self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) + self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod))) + self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod))) + self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod))) + self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod))) + self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1))) + + # calculations for posterior q(x_{t-1} | x_t, x_0) + posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / ( + 1. - alphas_cumprod) + self.v_posterior * betas + # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) + self.register_buffer('posterior_variance', to_torch(posterior_variance)) + # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain + self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20)))) + self.register_buffer('posterior_mean_coef1', to_torch( + betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))) + self.register_buffer('posterior_mean_coef2', to_torch( + (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod))) + + if self.parameterization == "eps": + lvlb_weights = self.betas ** 2 / ( + 2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod)) + elif self.parameterization == "x0": + lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2. * 1 - torch.Tensor(alphas_cumprod)) + else: + raise NotImplementedError("mu not supported") + # TODO how to choose this term + lvlb_weights[0] = lvlb_weights[1] + self.register_buffer('lvlb_weights', lvlb_weights, persistent=False) + assert not torch.isnan(self.lvlb_weights).all() + + @contextmanager + def ema_scope(self, context=None): + if self.use_ema: + self.model_ema.store(self.model.parameters()) + self.model_ema.copy_to(self.model) + if context is not None: + print(f"{context}: Switched to EMA weights") + try: + yield None + finally: + if self.use_ema: + self.model_ema.restore(self.model.parameters()) + if context is not None: + print(f"{context}: Restored training weights") + + def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): + sd = torch.load(path, map_location="cpu") + if "state_dict" in list(sd.keys()): + sd = sd["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict( + sd, strict=False) + print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") + if len(missing) > 0: + print(f"Missing Keys: {missing}") + if len(unexpected) > 0: + print(f"Unexpected Keys: {unexpected}") + + def q_mean_variance(self, x_start, t): + """ + Get the distribution q(x_t | x_0). + :param x_start: the [N x C x ...] tensor of noiseless inputs. + :param t: the number of diffusion steps (minus 1). Here, 0 means one step. + :return: A tuple (mean, variance, log_variance), all of x_start's shape. + """ + mean = (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start) + variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape) + log_variance = extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape) + return mean, variance, log_variance + + def predict_start_from_noise(self, x_t, t, noise): + return ( + extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - + extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise + ) + + def q_posterior(self, x_start, x_t, t): + posterior_mean = ( + extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start + + extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t + ) + posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape) + posterior_log_variance_clipped = extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape) + return posterior_mean, posterior_variance, posterior_log_variance_clipped + + def p_mean_variance(self, x, t, clip_denoised: bool): + model_out = self.model(x, t) + if self.parameterization == "eps": + x_recon = self.predict_start_from_noise(x, t=t, noise=model_out) + elif self.parameterization == "x0": + x_recon = model_out + if clip_denoised: + x_recon.clamp_(-1., 1.) + + model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t) + return model_mean, posterior_variance, posterior_log_variance + + @torch.no_grad() + def p_sample(self, x, t, clip_denoised=True, repeat_noise=False): + b, *_, device = *x.shape, x.device + model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, clip_denoised=clip_denoised) + noise = noise_like(x.shape, device, repeat_noise) + # no noise when t == 0 + nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise + + @torch.no_grad() + def p_sample_loop(self, shape, return_intermediates=False): + device = self.betas.device + b = shape[0] + img = torch.randn(shape, device=device) + intermediates = [img] + for i in tqdm(reversed(range(0, self.num_timesteps)), desc='Sampling t', total=self.num_timesteps): + img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long), + clip_denoised=self.clip_denoised) + if i % self.log_every_t == 0 or i == self.num_timesteps - 1: + intermediates.append(img) + if return_intermediates: + return img, intermediates + return img + + @torch.no_grad() + def sample(self, batch_size=16, return_intermediates=False): + image_size = self.image_size + channels = self.channels + return self.p_sample_loop((batch_size, channels, image_size, image_size), + return_intermediates=return_intermediates) + + def q_sample(self, x_start, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise) + + def get_loss(self, pred, target, mean=True): + if self.loss_type == 'l1': + loss = (target - pred).abs() + if mean: + loss = loss.mean() + elif self.loss_type == 'l2': + if mean: + loss = torch.nn.functional.mse_loss(target, pred) + else: + loss = torch.nn.functional.mse_loss(target, pred, reduction='none') + else: + raise NotImplementedError("unknown loss type '{loss_type}'") + + return loss + + def p_losses(self, x_start, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) + model_out = self.model(x_noisy, t) + + loss_dict = {} + if self.parameterization == "eps": + target = noise + elif self.parameterization == "x0": + target = x_start + else: + raise NotImplementedError(f"Paramterization {self.parameterization} not yet supported") + + loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3]) + + log_prefix = 'train' if self.training else 'val' + + loss_dict.update({f'{log_prefix}/loss_simple': loss.mean()}) + loss_simple = loss.mean() * self.l_simple_weight + + loss_vlb = (self.lvlb_weights[t] * loss).mean() + loss_dict.update({f'{log_prefix}/loss_vlb': loss_vlb}) + + loss = loss_simple + self.original_elbo_weight * loss_vlb + + loss_dict.update({f'{log_prefix}/loss': loss}) + + return loss, loss_dict + + def forward(self, x, *args, **kwargs): + # b, c, h, w, device, img_size, = *x.shape, x.device, self.image_size + # assert h == img_size and w == img_size, f'height and width of image must be {img_size}' + t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long() + return self.p_losses(x, t, *args, **kwargs) + + def get_input(self, batch, k): + x = batch[k] + if len(x.shape) == 3: + x = x[..., None] + x = rearrange(x, 'b h w c -> b c h w') + x = x.to(memory_format=torch.contiguous_format).float() + return x + + def shared_step(self, batch): + x = self.get_input(batch, self.first_stage_key) + loss, loss_dict = self(x) + return loss, loss_dict + + def training_step(self, batch, batch_idx): + loss, loss_dict = self.shared_step(batch) + + self.log_dict(loss_dict, prog_bar=True, + logger=True, on_step=True, on_epoch=True) + + self.log("global_step", self.global_step, + prog_bar=True, logger=True, on_step=True, on_epoch=False) + + if self.use_scheduler: + lr = self.optimizers().param_groups[0]['lr'] + self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False) + + return loss + + @torch.no_grad() + def validation_step(self, batch, batch_idx): + _, loss_dict_no_ema = self.shared_step(batch) + with self.ema_scope(): + _, loss_dict_ema = self.shared_step(batch) + loss_dict_ema = {key + '_ema': loss_dict_ema[key] for key in loss_dict_ema} + self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True) + self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True) + + def on_train_batch_end(self, *args, **kwargs): + if self.use_ema: + self.model_ema(self.model) + + def _get_rows_from_list(self, samples): + n_imgs_per_row = len(samples) + denoise_grid = rearrange(samples, 'n b c h w -> b n c h w') + denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w') + denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row) + return denoise_grid + + @torch.no_grad() + def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs): + log = dict() + x = self.get_input(batch, self.first_stage_key) + N = min(x.shape[0], N) + n_row = min(x.shape[0], n_row) + x = x.to(self.device)[:N] + log["inputs"] = x + + # get diffusion row + diffusion_row = list() + x_start = x[:n_row] + + for t in range(self.num_timesteps): + if t % self.log_every_t == 0 or t == self.num_timesteps - 1: + t = repeat(torch.tensor([t]), '1 -> b', b=n_row) + t = t.to(self.device).long() + noise = torch.randn_like(x_start) + x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) + diffusion_row.append(x_noisy) + + log["diffusion_row"] = self._get_rows_from_list(diffusion_row) + + if sample: + # get denoise row + with self.ema_scope("Plotting"): + samples, denoise_row = self.sample(batch_size=N, return_intermediates=True) + + log["samples"] = samples + log["denoise_row"] = self._get_rows_from_list(denoise_row) + + if return_keys: + if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0: + return log + else: + return {key: log[key] for key in return_keys} + return log + + def configure_optimizers(self): + lr = self.learning_rate + params = list(self.model.parameters()) + if self.learn_logvar: + params = params + [self.logvar] + opt = torch.optim.AdamW(params, lr=lr) + return opt + + +class LatentDiffusion(DDPM): + """main class""" + def __init__(self, + first_stage_config, + cond_stage_config, + num_timesteps_cond=None, + cond_stage_key="image", + cond_stage_trainable=False, + concat_mode=True, + cond_stage_forward=None, + conditioning_key=None, + scale_factor=1.0, + scale_by_std=False, + *args, **kwargs): + self.num_timesteps_cond = default(num_timesteps_cond, 1) + self.scale_by_std = scale_by_std + assert self.num_timesteps_cond <= kwargs['timesteps'] + # for backwards compatibility after implementation of DiffusionWrapper + if conditioning_key is None: + conditioning_key = 'concat' if concat_mode else 'crossattn' + if cond_stage_config == '__is_unconditional__': + conditioning_key = None + ckpt_path = kwargs.pop("ckpt_path", None) + ignore_keys = kwargs.pop("ignore_keys", []) + super().__init__(conditioning_key=conditioning_key, *args, **kwargs) + self.concat_mode = concat_mode + self.cond_stage_trainable = cond_stage_trainable + self.cond_stage_key = cond_stage_key + try: + self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1 + except: + self.num_downs = 0 + if not scale_by_std: + self.scale_factor = scale_factor + else: + self.register_buffer('scale_factor', torch.tensor(scale_factor)) + self.instantiate_first_stage(first_stage_config) + self.instantiate_cond_stage(cond_stage_config) + self.cond_stage_forward = cond_stage_forward + self.clip_denoised = False + self.bbox_tokenizer = None + + self.restarted_from_ckpt = False + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys) + self.restarted_from_ckpt = True + + def make_cond_schedule(self, ): + self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long) + ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long() + self.cond_ids[:self.num_timesteps_cond] = ids + + @rank_zero_only + @torch.no_grad() + def on_train_batch_start(self, batch, batch_idx, dataloader_idx): + # only for very first batch + if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 and batch_idx == 0 and not self.restarted_from_ckpt: + assert self.scale_factor == 1., 'rather not use custom rescaling and std-rescaling simultaneously' + # set rescale weight to 1./std of encodings + print("### USING STD-RESCALING ###") + x = super().get_input(batch, self.first_stage_key) + x = x.to(self.device) + encoder_posterior = self.encode_first_stage(x) + z = self.get_first_stage_encoding(encoder_posterior).detach() + del self.scale_factor + self.register_buffer('scale_factor', 1. / z.flatten().std()) + print(f"setting self.scale_factor to {self.scale_factor}") + print("### USING STD-RESCALING ###") + + def register_schedule(self, + given_betas=None, beta_schedule="linear", timesteps=1000, + linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): + super().register_schedule(given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s) + + self.shorten_cond_schedule = self.num_timesteps_cond > 1 + if self.shorten_cond_schedule: + self.make_cond_schedule() + + def instantiate_first_stage(self, config): + model = instantiate_from_config(config) + self.first_stage_model = model.eval() + self.first_stage_model.train = disabled_train + for param in self.first_stage_model.parameters(): + param.requires_grad = False + + def instantiate_cond_stage(self, config): + if not self.cond_stage_trainable: + if config == "__is_first_stage__": + print("Using first stage also as cond stage.") + self.cond_stage_model = self.first_stage_model + elif config == "__is_unconditional__": + print(f"Training {self.__class__.__name__} as an unconditional model.") + self.cond_stage_model = None + # self.be_unconditional = True + else: + model = instantiate_from_config(config) + self.cond_stage_model = model.eval() + self.cond_stage_model.train = disabled_train + for param in self.cond_stage_model.parameters(): + param.requires_grad = False + else: + assert config != '__is_first_stage__' + assert config != '__is_unconditional__' + model = instantiate_from_config(config) + self.cond_stage_model = model + + def _get_denoise_row_from_list(self, samples, desc='', force_no_decoder_quantization=False): + denoise_row = [] + for zd in tqdm(samples, desc=desc): + denoise_row.append(self.decode_first_stage(zd.to(self.device), + force_not_quantize=force_no_decoder_quantization)) + n_imgs_per_row = len(denoise_row) + denoise_row = torch.stack(denoise_row) # n_log_step, n_row, C, H, W + denoise_grid = rearrange(denoise_row, 'n b c h w -> b n c h w') + denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w') + denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row) + return denoise_grid + + def get_first_stage_encoding(self, encoder_posterior): + if isinstance(encoder_posterior, DiagonalGaussianDistribution): + z = encoder_posterior.sample() + elif isinstance(encoder_posterior, torch.Tensor): + z = encoder_posterior + else: + raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented") + return self.scale_factor * z + + def get_learned_conditioning(self, c): + if self.cond_stage_forward is None: + if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode): + c = self.cond_stage_model.encode(c) + if isinstance(c, DiagonalGaussianDistribution): + c = c.mode() + else: + c = self.cond_stage_model(c) + else: + assert hasattr(self.cond_stage_model, self.cond_stage_forward) + c = getattr(self.cond_stage_model, self.cond_stage_forward)(c) + return c + + def meshgrid(self, h, w): + y = torch.arange(0, h).view(h, 1, 1).repeat(1, w, 1) + x = torch.arange(0, w).view(1, w, 1).repeat(h, 1, 1) + + arr = torch.cat([y, x], dim=-1) + return arr + + def delta_border(self, h, w): + """ + :param h: height + :param w: width + :return: normalized distance to image border, + wtith min distance = 0 at border and max dist = 0.5 at image center + """ + lower_right_corner = torch.tensor([h - 1, w - 1]).view(1, 1, 2) + arr = self.meshgrid(h, w) / lower_right_corner + dist_left_up = torch.min(arr, dim=-1, keepdims=True)[0] + dist_right_down = torch.min(1 - arr, dim=-1, keepdims=True)[0] + edge_dist = torch.min(torch.cat([dist_left_up, dist_right_down], dim=-1), dim=-1)[0] + return edge_dist + + def get_weighting(self, h, w, Ly, Lx, device): + weighting = self.delta_border(h, w) + weighting = torch.clip(weighting, self.split_input_params["clip_min_weight"], + self.split_input_params["clip_max_weight"], ) + weighting = weighting.view(1, h * w, 1).repeat(1, 1, Ly * Lx).to(device) + + if self.split_input_params["tie_braker"]: + L_weighting = self.delta_border(Ly, Lx) + L_weighting = torch.clip(L_weighting, + self.split_input_params["clip_min_tie_weight"], + self.split_input_params["clip_max_tie_weight"]) + + L_weighting = L_weighting.view(1, 1, Ly * Lx).to(device) + weighting = weighting * L_weighting + return weighting + + def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1): # todo load once not every time, shorten code + """ + :param x: img of size (bs, c, h, w) + :return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1]) + """ + bs, nc, h, w = x.shape + + # number of crops in image + Ly = (h - kernel_size[0]) // stride[0] + 1 + Lx = (w - kernel_size[1]) // stride[1] + 1 + + if uf == 1 and df == 1: + fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride) + unfold = torch.nn.Unfold(**fold_params) + + fold = torch.nn.Fold(output_size=x.shape[2:], **fold_params) + + weighting = self.get_weighting(kernel_size[0], kernel_size[1], Ly, Lx, x.device).to(x.dtype) + normalization = fold(weighting).view(1, 1, h, w) # normalizes the overlap + weighting = weighting.view((1, 1, kernel_size[0], kernel_size[1], Ly * Lx)) + + elif uf > 1 and df == 1: + fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride) + unfold = torch.nn.Unfold(**fold_params) + + fold_params2 = dict(kernel_size=(kernel_size[0] * uf, kernel_size[0] * uf), + dilation=1, padding=0, + stride=(stride[0] * uf, stride[1] * uf)) + fold = torch.nn.Fold(output_size=(x.shape[2] * uf, x.shape[3] * uf), **fold_params2) + + weighting = self.get_weighting(kernel_size[0] * uf, kernel_size[1] * uf, Ly, Lx, x.device).to(x.dtype) + normalization = fold(weighting).view(1, 1, h * uf, w * uf) # normalizes the overlap + weighting = weighting.view((1, 1, kernel_size[0] * uf, kernel_size[1] * uf, Ly * Lx)) + + elif df > 1 and uf == 1: + fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride) + unfold = torch.nn.Unfold(**fold_params) + + fold_params2 = dict(kernel_size=(kernel_size[0] // df, kernel_size[0] // df), + dilation=1, padding=0, + stride=(stride[0] // df, stride[1] // df)) + fold = torch.nn.Fold(output_size=(x.shape[2] // df, x.shape[3] // df), **fold_params2) + + weighting = self.get_weighting(kernel_size[0] // df, kernel_size[1] // df, Ly, Lx, x.device).to(x.dtype) + normalization = fold(weighting).view(1, 1, h // df, w // df) # normalizes the overlap + weighting = weighting.view((1, 1, kernel_size[0] // df, kernel_size[1] // df, Ly * Lx)) + + else: + raise NotImplementedError + + return fold, unfold, normalization, weighting + + @torch.no_grad() + def get_input(self, batch, k, return_first_stage_outputs=False, force_c_encode=False, + cond_key=None, return_original_cond=False, bs=None): + x = super().get_input(batch, k) + if bs is not None: + x = x[:bs] + x = x.to(self.device) + encoder_posterior = self.encode_first_stage(x) + z = self.get_first_stage_encoding(encoder_posterior).detach() + + if self.model.conditioning_key is not None: + if cond_key is None: + cond_key = self.cond_stage_key + if cond_key != self.first_stage_key: + if cond_key in ['caption', 'coordinates_bbox']: + xc = batch[cond_key] + elif cond_key == 'class_label': + xc = batch + else: + xc = super().get_input(batch, cond_key).to(self.device) + else: + xc = x + if not self.cond_stage_trainable or force_c_encode: + if isinstance(xc, dict) or isinstance(xc, list): + # import pudb; pudb.set_trace() + c = self.get_learned_conditioning(xc) + else: + c = self.get_learned_conditioning(xc.to(self.device)) + else: + c = xc + if bs is not None: + c = c[:bs] + + if self.use_positional_encodings: + pos_x, pos_y = self.compute_latent_shifts(batch) + ckey = __conditioning_keys__[self.model.conditioning_key] + c = {ckey: c, 'pos_x': pos_x, 'pos_y': pos_y} + + else: + c = None + xc = None + if self.use_positional_encodings: + pos_x, pos_y = self.compute_latent_shifts(batch) + c = {'pos_x': pos_x, 'pos_y': pos_y} + out = [z, c] + if return_first_stage_outputs: + xrec = self.decode_first_stage(z) + out.extend([x, xrec]) + if return_original_cond: + out.append(xc) + return out + + @torch.no_grad() + def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False): + if predict_cids: + if z.dim() == 4: + z = torch.argmax(z.exp(), dim=1).long() + z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None) + z = rearrange(z, 'b h w c -> b c h w').contiguous() + + z = 1. / self.scale_factor * z + + if hasattr(self, "split_input_params"): + if self.split_input_params["patch_distributed_vq"]: + ks = self.split_input_params["ks"] # eg. (128, 128) + stride = self.split_input_params["stride"] # eg. (64, 64) + uf = self.split_input_params["vqf"] + bs, nc, h, w = z.shape + if ks[0] > h or ks[1] > w: + ks = (min(ks[0], h), min(ks[1], w)) + print("reducing Kernel") + + if stride[0] > h or stride[1] > w: + stride = (min(stride[0], h), min(stride[1], w)) + print("reducing stride") + + fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf) + + z = unfold(z) # (bn, nc * prod(**ks), L) + # 1. Reshape to img shape + z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L ) + + # 2. apply model loop over last dim + if isinstance(self.first_stage_model, VQModelInterface): + output_list = [self.first_stage_model.decode(z[:, :, :, :, i], + force_not_quantize=predict_cids or force_not_quantize) + for i in range(z.shape[-1])] + else: + + output_list = [self.first_stage_model.decode(z[:, :, :, :, i]) + for i in range(z.shape[-1])] + + o = torch.stack(output_list, axis=-1) # # (bn, nc, ks[0], ks[1], L) + o = o * weighting + # Reverse 1. reshape to img shape + o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L) + # stitch crops together + decoded = fold(o) + decoded = decoded / normalization # norm is shape (1, 1, h, w) + return decoded + else: + if isinstance(self.first_stage_model, VQModelInterface): + return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize) + else: + return self.first_stage_model.decode(z) + + else: + if isinstance(self.first_stage_model, VQModelInterface): + return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize) + else: + return self.first_stage_model.decode(z) + + # same as above but without decorator + def differentiable_decode_first_stage(self, z, predict_cids=False, force_not_quantize=False): + if predict_cids: + if z.dim() == 4: + z = torch.argmax(z.exp(), dim=1).long() + z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None) + z = rearrange(z, 'b h w c -> b c h w').contiguous() + + z = 1. / self.scale_factor * z + + if hasattr(self, "split_input_params"): + if self.split_input_params["patch_distributed_vq"]: + ks = self.split_input_params["ks"] # eg. (128, 128) + stride = self.split_input_params["stride"] # eg. (64, 64) + uf = self.split_input_params["vqf"] + bs, nc, h, w = z.shape + if ks[0] > h or ks[1] > w: + ks = (min(ks[0], h), min(ks[1], w)) + print("reducing Kernel") + + if stride[0] > h or stride[1] > w: + stride = (min(stride[0], h), min(stride[1], w)) + print("reducing stride") + + fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf) + + z = unfold(z) # (bn, nc * prod(**ks), L) + # 1. Reshape to img shape + z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L ) + + # 2. apply model loop over last dim + if isinstance(self.first_stage_model, VQModelInterface): + output_list = [self.first_stage_model.decode(z[:, :, :, :, i], + force_not_quantize=predict_cids or force_not_quantize) + for i in range(z.shape[-1])] + else: + + output_list = [self.first_stage_model.decode(z[:, :, :, :, i]) + for i in range(z.shape[-1])] + + o = torch.stack(output_list, axis=-1) # # (bn, nc, ks[0], ks[1], L) + o = o * weighting + # Reverse 1. reshape to img shape + o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L) + # stitch crops together + decoded = fold(o) + decoded = decoded / normalization # norm is shape (1, 1, h, w) + return decoded + else: + if isinstance(self.first_stage_model, VQModelInterface): + return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize) + else: + return self.first_stage_model.decode(z) + + else: + if isinstance(self.first_stage_model, VQModelInterface): + return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize) + else: + return self.first_stage_model.decode(z) + + @torch.no_grad() + def encode_first_stage(self, x): + if hasattr(self, "split_input_params"): + if self.split_input_params["patch_distributed_vq"]: + ks = self.split_input_params["ks"] # eg. (128, 128) + stride = self.split_input_params["stride"] # eg. (64, 64) + df = self.split_input_params["vqf"] + self.split_input_params['original_image_size'] = x.shape[-2:] + bs, nc, h, w = x.shape + if ks[0] > h or ks[1] > w: + ks = (min(ks[0], h), min(ks[1], w)) + print("reducing Kernel") + + if stride[0] > h or stride[1] > w: + stride = (min(stride[0], h), min(stride[1], w)) + print("reducing stride") + + fold, unfold, normalization, weighting = self.get_fold_unfold(x, ks, stride, df=df) + z = unfold(x) # (bn, nc * prod(**ks), L) + # Reshape to img shape + z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L ) + + output_list = [self.first_stage_model.encode(z[:, :, :, :, i]) + for i in range(z.shape[-1])] + + o = torch.stack(output_list, axis=-1) + o = o * weighting + + # Reverse reshape to img shape + o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L) + # stitch crops together + decoded = fold(o) + decoded = decoded / normalization + return decoded + + else: + return self.first_stage_model.encode(x) + else: + return self.first_stage_model.encode(x) + + def shared_step(self, batch, **kwargs): + x, c = self.get_input(batch, self.first_stage_key) + loss = self(x, c) + return loss + + def forward(self, x, c, *args, **kwargs): + t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long() + if self.model.conditioning_key is not None: + assert c is not None + if self.cond_stage_trainable: + c = self.get_learned_conditioning(c) + if self.shorten_cond_schedule: # TODO: drop this option + tc = self.cond_ids[t].to(self.device) + c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float())) + return self.p_losses(x, c, t, *args, **kwargs) + + def _rescale_annotations(self, bboxes, crop_coordinates): # TODO: move to dataset + def rescale_bbox(bbox): + x0 = clamp((bbox[0] - crop_coordinates[0]) / crop_coordinates[2]) + y0 = clamp((bbox[1] - crop_coordinates[1]) / crop_coordinates[3]) + w = min(bbox[2] / crop_coordinates[2], 1 - x0) + h = min(bbox[3] / crop_coordinates[3], 1 - y0) + return x0, y0, w, h + + return [rescale_bbox(b) for b in bboxes] + + def apply_model(self, x_noisy, t, cond, return_ids=False): + + if isinstance(cond, dict): + # hybrid case, cond is exptected to be a dict + pass + else: + if not isinstance(cond, list): + cond = [cond] + key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn' + cond = {key: cond} + + if hasattr(self, "split_input_params"): + assert len(cond) == 1 # todo can only deal with one conditioning atm + assert not return_ids + ks = self.split_input_params["ks"] # eg. (128, 128) + stride = self.split_input_params["stride"] # eg. (64, 64) + + h, w = x_noisy.shape[-2:] + + fold, unfold, normalization, weighting = self.get_fold_unfold(x_noisy, ks, stride) + + z = unfold(x_noisy) # (bn, nc * prod(**ks), L) + # Reshape to img shape + z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L ) + z_list = [z[:, :, :, :, i] for i in range(z.shape[-1])] + + if self.cond_stage_key in ["image", "LR_image", "segmentation", + 'bbox_img'] and self.model.conditioning_key: # todo check for completeness + c_key = next(iter(cond.keys())) # get key + c = next(iter(cond.values())) # get value + assert (len(c) == 1) # todo extend to list with more than one elem + c = c[0] # get element + + c = unfold(c) + c = c.view((c.shape[0], -1, ks[0], ks[1], c.shape[-1])) # (bn, nc, ks[0], ks[1], L ) + + cond_list = [{c_key: [c[:, :, :, :, i]]} for i in range(c.shape[-1])] + + elif self.cond_stage_key == 'coordinates_bbox': + assert 'original_image_size' in self.split_input_params, 'BoudingBoxRescaling is missing original_image_size' + + # assuming padding of unfold is always 0 and its dilation is always 1 + n_patches_per_row = int((w - ks[0]) / stride[0] + 1) + full_img_h, full_img_w = self.split_input_params['original_image_size'] + # as we are operating on latents, we need the factor from the original image size to the + # spatial latent size to properly rescale the crops for regenerating the bbox annotations + num_downs = self.first_stage_model.encoder.num_resolutions - 1 + rescale_latent = 2 ** (num_downs) + + # get top left postions of patches as conforming for the bbbox tokenizer, therefore we + # need to rescale the tl patch coordinates to be in between (0,1) + tl_patch_coordinates = [(rescale_latent * stride[0] * (patch_nr % n_patches_per_row) / full_img_w, + rescale_latent * stride[1] * (patch_nr // n_patches_per_row) / full_img_h) + for patch_nr in range(z.shape[-1])] + + # patch_limits are tl_coord, width and height coordinates as (x_tl, y_tl, h, w) + patch_limits = [(x_tl, y_tl, + rescale_latent * ks[0] / full_img_w, + rescale_latent * ks[1] / full_img_h) for x_tl, y_tl in tl_patch_coordinates] + # patch_values = [(np.arange(x_tl,min(x_tl+ks, 1.)),np.arange(y_tl,min(y_tl+ks, 1.))) for x_tl, y_tl in tl_patch_coordinates] + + # tokenize crop coordinates for the bounding boxes of the respective patches + patch_limits_tknzd = [torch.LongTensor(self.bbox_tokenizer._crop_encoder(bbox))[None].to(self.device) + for bbox in patch_limits] # list of length l with tensors of shape (1, 2) + print(patch_limits_tknzd[0].shape) + # cut tknzd crop position from conditioning + assert isinstance(cond, dict), 'cond must be dict to be fed into model' + cut_cond = cond['c_crossattn'][0][..., :-2].to(self.device) + print(cut_cond.shape) + + adapted_cond = torch.stack([torch.cat([cut_cond, p], dim=1) for p in patch_limits_tknzd]) + adapted_cond = rearrange(adapted_cond, 'l b n -> (l b) n') + print(adapted_cond.shape) + adapted_cond = self.get_learned_conditioning(adapted_cond) + print(adapted_cond.shape) + adapted_cond = rearrange(adapted_cond, '(l b) n d -> l b n d', l=z.shape[-1]) + print(adapted_cond.shape) + + cond_list = [{'c_crossattn': [e]} for e in adapted_cond] + + else: + cond_list = [cond for i in range(z.shape[-1])] # Todo make this more efficient + + # apply model by loop over crops + output_list = [self.model(z_list[i], t, **cond_list[i]) for i in range(z.shape[-1])] + assert not isinstance(output_list[0], + tuple) # todo cant deal with multiple model outputs check this never happens + + o = torch.stack(output_list, axis=-1) + o = o * weighting + # Reverse reshape to img shape + o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L) + # stitch crops together + x_recon = fold(o) / normalization + + else: + x_recon = self.model(x_noisy, t, **cond) + + if isinstance(x_recon, tuple) and not return_ids: + return x_recon[0] + else: + return x_recon + + def _predict_eps_from_xstart(self, x_t, t, pred_xstart): + return (extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart) / \ + extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) + + def _prior_bpd(self, x_start): + """ + Get the prior KL term for the variational lower-bound, measured in + bits-per-dim. + This term can't be optimized, as it only depends on the encoder. + :param x_start: the [N x C x ...] tensor of inputs. + :return: a batch of [N] KL values (in bits), one per batch element. + """ + batch_size = x_start.shape[0] + t = torch.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device) + qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t) + kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0) + return mean_flat(kl_prior) / np.log(2.0) + + def p_losses(self, x_start, cond, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) + model_output = self.apply_model(x_noisy, t, cond) + + loss_dict = {} + prefix = 'train' if self.training else 'val' + + if self.parameterization == "x0": + target = x_start + elif self.parameterization == "eps": + target = noise + else: + raise NotImplementedError() + + loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3]) + loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()}) + + logvar_t = self.logvar[t].to(self.device) + loss = loss_simple / torch.exp(logvar_t) + logvar_t + # loss = loss_simple / torch.exp(self.logvar) + self.logvar + if self.learn_logvar: + loss_dict.update({f'{prefix}/loss_gamma': loss.mean()}) + loss_dict.update({'logvar': self.logvar.data.mean()}) + + loss = self.l_simple_weight * loss.mean() + + loss_vlb = self.get_loss(model_output, target, mean=False).mean(dim=(1, 2, 3)) + loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean() + loss_dict.update({f'{prefix}/loss_vlb': loss_vlb}) + loss += (self.original_elbo_weight * loss_vlb) + loss_dict.update({f'{prefix}/loss': loss}) + + return loss, loss_dict + + def p_mean_variance(self, x, c, t, clip_denoised: bool, return_codebook_ids=False, quantize_denoised=False, + return_x0=False, score_corrector=None, corrector_kwargs=None): + t_in = t + model_out = self.apply_model(x, t_in, c, return_ids=return_codebook_ids) + + if score_corrector is not None: + assert self.parameterization == "eps" + model_out = score_corrector.modify_score(self, model_out, x, t, c, **corrector_kwargs) + + if return_codebook_ids: + model_out, logits = model_out + + if self.parameterization == "eps": + x_recon = self.predict_start_from_noise(x, t=t, noise=model_out) + elif self.parameterization == "x0": + x_recon = model_out + else: + raise NotImplementedError() + + if clip_denoised: + x_recon.clamp_(-1., 1.) + if quantize_denoised: + x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon) + model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t) + if return_codebook_ids: + return model_mean, posterior_variance, posterior_log_variance, logits + elif return_x0: + return model_mean, posterior_variance, posterior_log_variance, x_recon + else: + return model_mean, posterior_variance, posterior_log_variance + + @torch.no_grad() + def p_sample(self, x, c, t, clip_denoised=False, repeat_noise=False, + return_codebook_ids=False, quantize_denoised=False, return_x0=False, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None): + b, *_, device = *x.shape, x.device + outputs = self.p_mean_variance(x=x, c=c, t=t, clip_denoised=clip_denoised, + return_codebook_ids=return_codebook_ids, + quantize_denoised=quantize_denoised, + return_x0=return_x0, + score_corrector=score_corrector, corrector_kwargs=corrector_kwargs) + if return_codebook_ids: + raise DeprecationWarning("Support dropped.") + model_mean, _, model_log_variance, logits = outputs + elif return_x0: + model_mean, _, model_log_variance, x0 = outputs + else: + model_mean, _, model_log_variance = outputs + + noise = noise_like(x.shape, device, repeat_noise) * temperature + if noise_dropout > 0.: + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + # no noise when t == 0 + nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) + + if return_codebook_ids: + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, logits.argmax(dim=1) + if return_x0: + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, x0 + else: + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise + + @torch.no_grad() + def progressive_denoising(self, cond, shape, verbose=True, callback=None, quantize_denoised=False, + img_callback=None, mask=None, x0=None, temperature=1., noise_dropout=0., + score_corrector=None, corrector_kwargs=None, batch_size=None, x_T=None, start_T=None, + log_every_t=None): + if not log_every_t: + log_every_t = self.log_every_t + timesteps = self.num_timesteps + if batch_size is not None: + b = batch_size if batch_size is not None else shape[0] + shape = [batch_size] + list(shape) + else: + b = batch_size = shape[0] + if x_T is None: + img = torch.randn(shape, device=self.device) + else: + img = x_T + intermediates = [] + if cond is not None: + if isinstance(cond, dict): + cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else + list(map(lambda x: x[:batch_size], cond[key])) for key in cond} + else: + cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size] + + if start_T is not None: + timesteps = min(timesteps, start_T) + iterator = tqdm(reversed(range(0, timesteps)), desc='Progressive Generation', + total=timesteps) if verbose else reversed( + range(0, timesteps)) + if type(temperature) == float: + temperature = [temperature] * timesteps + + for i in iterator: + ts = torch.full((b,), i, device=self.device, dtype=torch.long) + if self.shorten_cond_schedule: + assert self.model.conditioning_key != 'hybrid' + tc = self.cond_ids[ts].to(cond.device) + cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond)) + + img, x0_partial = self.p_sample(img, cond, ts, + clip_denoised=self.clip_denoised, + quantize_denoised=quantize_denoised, return_x0=True, + temperature=temperature[i], noise_dropout=noise_dropout, + score_corrector=score_corrector, corrector_kwargs=corrector_kwargs) + if mask is not None: + assert x0 is not None + img_orig = self.q_sample(x0, ts) + img = img_orig * mask + (1. - mask) * img + + if i % log_every_t == 0 or i == timesteps - 1: + intermediates.append(x0_partial) + if callback: callback(i) + if img_callback: img_callback(img, x0_partial, i) + return img, intermediates + + @torch.no_grad() + def p_sample_loop(self, cond, shape, return_intermediates=False, + x_T=None, verbose=True, callback=None, timesteps=None, quantize_denoised=False, + mask=None, x0=None, img_callback=None, start_T=None, + log_every_t=None): + + if not log_every_t: + log_every_t = self.log_every_t + device = self.betas.device + b = shape[0] + if x_T is None: + img = torch.randn(shape, device=device) + else: + img = x_T + + intermediates = [img] + if timesteps is None: + timesteps = self.num_timesteps + + if start_T is not None: + timesteps = min(timesteps, start_T) + iterator = tqdm(reversed(range(0, timesteps)), desc='Sampling t', total=timesteps) if verbose else reversed( + range(0, timesteps)) + + if mask is not None: + assert x0 is not None + assert x0.shape[2:3] == mask.shape[2:3] # spatial size has to match + + for i in iterator: + ts = torch.full((b,), i, device=device, dtype=torch.long) + if self.shorten_cond_schedule: + assert self.model.conditioning_key != 'hybrid' + tc = self.cond_ids[ts].to(cond.device) + cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond)) + + img, pred_x0 = self.p_sample(img, cond, ts, + clip_denoised=self.clip_denoised, + quantize_denoised=quantize_denoised, + return_x0=True) + if mask is not None: + img_orig = self.q_sample(x0, ts) + img = img_orig * mask + (1. - mask) * img + + if i % log_every_t == 0 or i == timesteps - 1: + intermediates.append(img) + if callback: callback(i) + if img_callback: img_callback(img, pred_x0, i) + + if return_intermediates: + return img, intermediates + return img + + @torch.no_grad() + def sample(self, cond, batch_size=16, return_intermediates=False, x_T=None, + verbose=True, timesteps=None, quantize_denoised=False, + mask=None, x0=None, shape=None,**kwargs): + if shape is None: + shape = (batch_size, self.channels, self.image_size, self.image_size) + if cond is not None: + if isinstance(cond, dict): + cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else + list(map(lambda x: x[:batch_size], cond[key])) for key in cond} + else: + cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size] + return self.p_sample_loop(cond, + shape, + return_intermediates=return_intermediates, x_T=x_T, + verbose=verbose, timesteps=timesteps, quantize_denoised=quantize_denoised, + mask=mask, x0=x0) + + @torch.no_grad() + def sample_log(self,cond,batch_size,ddim, ddim_steps,**kwargs): + + if ddim: + ddim_sampler = DDIMSampler(self) + shape = (self.channels, self.image_size, self.image_size) + samples, intermediates =ddim_sampler.sample(ddim_steps,batch_size, + shape,cond,verbose=False,**kwargs) + + else: + samples, intermediates = self.sample(cond=cond, batch_size=batch_size, + return_intermediates=True,**kwargs) + + return samples, intermediates + + + @torch.no_grad() + def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None, + quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True, + plot_diffusion_rows=True, **kwargs): + + use_ddim = ddim_steps is not None + + log = dict() + z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key, + return_first_stage_outputs=True, + force_c_encode=True, + return_original_cond=True, + bs=N) + N = min(x.shape[0], N) + n_row = min(x.shape[0], n_row) + log["inputs"] = x + log["reconstruction"] = xrec + if self.model.conditioning_key is not None: + if hasattr(self.cond_stage_model, "decode"): + xc = self.cond_stage_model.decode(c) + log["conditioning"] = xc + elif self.cond_stage_key in ["caption"]: + xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["caption"]) + log["conditioning"] = xc + elif self.cond_stage_key == 'class_label': + xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"]) + log['conditioning'] = xc + elif isimage(xc): + log["conditioning"] = xc + if ismap(xc): + log["original_conditioning"] = self.to_rgb(xc) + + if plot_diffusion_rows: + # get diffusion row + diffusion_row = list() + z_start = z[:n_row] + for t in range(self.num_timesteps): + if t % self.log_every_t == 0 or t == self.num_timesteps - 1: + t = repeat(torch.tensor([t]), '1 -> b', b=n_row) + t = t.to(self.device).long() + noise = torch.randn_like(z_start) + z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise) + diffusion_row.append(self.decode_first_stage(z_noisy)) + + diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W + diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w') + diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w') + diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0]) + log["diffusion_row"] = diffusion_grid + + if sample: + # get denoise row + with self.ema_scope("Plotting"): + samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim, + ddim_steps=ddim_steps,eta=ddim_eta) + # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True) + x_samples = self.decode_first_stage(samples) + log["samples"] = x_samples + if plot_denoise_rows: + denoise_grid = self._get_denoise_row_from_list(z_denoise_row) + log["denoise_row"] = denoise_grid + + if quantize_denoised and not isinstance(self.first_stage_model, AutoencoderKL) and not isinstance( + self.first_stage_model, IdentityFirstStage): + # also display when quantizing x0 while sampling + with self.ema_scope("Plotting Quantized Denoised"): + samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim, + ddim_steps=ddim_steps,eta=ddim_eta, + quantize_denoised=True) + # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True, + # quantize_denoised=True) + x_samples = self.decode_first_stage(samples.to(self.device)) + log["samples_x0_quantized"] = x_samples + + if inpaint: + # make a simple center square + b, h, w = z.shape[0], z.shape[2], z.shape[3] + mask = torch.ones(N, h, w).to(self.device) + # zeros will be filled in + mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0. + mask = mask[:, None, ...] + with self.ema_scope("Plotting Inpaint"): + + samples, _ = self.sample_log(cond=c,batch_size=N,ddim=use_ddim, eta=ddim_eta, + ddim_steps=ddim_steps, x0=z[:N], mask=mask) + x_samples = self.decode_first_stage(samples.to(self.device)) + log["samples_inpainting"] = x_samples + log["mask"] = mask + + # outpaint + with self.ema_scope("Plotting Outpaint"): + samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,eta=ddim_eta, + ddim_steps=ddim_steps, x0=z[:N], mask=mask) + x_samples = self.decode_first_stage(samples.to(self.device)) + log["samples_outpainting"] = x_samples + + if plot_progressive_rows: + with self.ema_scope("Plotting Progressives"): + img, progressives = self.progressive_denoising(c, + shape=(self.channels, self.image_size, self.image_size), + batch_size=N) + prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation") + log["progressive_row"] = prog_row + + if return_keys: + if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0: + return log + else: + return {key: log[key] for key in return_keys} + return log + + def configure_optimizers(self): + lr = self.learning_rate + params = list(self.model.parameters()) + if self.cond_stage_trainable: + print(f"{self.__class__.__name__}: Also optimizing conditioner params!") + params = params + list(self.cond_stage_model.parameters()) + if self.learn_logvar: + print('Diffusion model optimizing logvar') + params.append(self.logvar) + opt = torch.optim.AdamW(params, lr=lr) + if self.use_scheduler: + assert 'target' in self.scheduler_config + scheduler = instantiate_from_config(self.scheduler_config) + + print("Setting up LambdaLR scheduler...") + scheduler = [ + { + 'scheduler': LambdaLR(opt, lr_lambda=scheduler.schedule), + 'interval': 'step', + 'frequency': 1 + }] + return [opt], scheduler + return opt + + @torch.no_grad() + def to_rgb(self, x): + x = x.float() + if not hasattr(self, "colorize"): + self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x) + x = nn.functional.conv2d(x, weight=self.colorize) + x = 2. * (x - x.min()) / (x.max() - x.min()) - 1. + return x + + +class DiffusionWrapper(pl.LightningModule): + def __init__(self, diff_model_config, conditioning_key): + super().__init__() + self.diffusion_model = instantiate_from_config(diff_model_config) + self.conditioning_key = conditioning_key + assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm'] + + def forward(self, x, t, c_concat: list = None, c_crossattn: list = None): + if self.conditioning_key is None: + out = self.diffusion_model(x, t) + elif self.conditioning_key == 'concat': + xc = torch.cat([x] + c_concat, dim=1) + out = self.diffusion_model(xc, t) + elif self.conditioning_key == 'crossattn': + cc = torch.cat(c_crossattn, 1) + out = self.diffusion_model(x, t, context=cc) + elif self.conditioning_key == 'hybrid': + xc = torch.cat([x] + c_concat, dim=1) + cc = torch.cat(c_crossattn, 1) + out = self.diffusion_model(xc, t, context=cc) + elif self.conditioning_key == 'adm': + cc = c_crossattn[0] + out = self.diffusion_model(x, t, y=cc) + else: + raise NotImplementedError() + + return out + + +class Layout2ImgDiffusion(LatentDiffusion): + # TODO: move all layout-specific hacks to this class + def __init__(self, cond_stage_key, *args, **kwargs): + assert cond_stage_key == 'coordinates_bbox', 'Layout2ImgDiffusion only for cond_stage_key="coordinates_bbox"' + super().__init__(cond_stage_key=cond_stage_key, *args, **kwargs) + + def log_images(self, batch, N=8, *args, **kwargs): + logs = super().log_images(batch=batch, N=N, *args, **kwargs) + + key = 'train' if self.training else 'validation' + dset = self.trainer.datamodule.datasets[key] + mapper = dset.conditional_builders[self.cond_stage_key] + + bbox_imgs = [] + map_fn = lambda catno: dset.get_textual_label(dset.get_category_id(catno)) + for tknzd_bbox in batch[self.cond_stage_key][:N]: + bboximg = mapper.plot(tknzd_bbox.detach().cpu(), map_fn, (256, 256)) + bbox_imgs.append(bboximg) + + cond_img = torch.stack(bbox_imgs, dim=0) + logs['bbox_image'] = cond_img + return logs diff --git a/deforum-stable-diffusion/src/ldm/models/diffusion/plms.py b/deforum-stable-diffusion/src/ldm/models/diffusion/plms.py new file mode 100644 index 0000000000000000000000000000000000000000..30396f61708423c780b808219cdb474a17f40b14 --- /dev/null +++ b/deforum-stable-diffusion/src/ldm/models/diffusion/plms.py @@ -0,0 +1,237 @@ +"""SAMPLING ONLY.""" + +import torch +import numpy as np +#from tqdm.notebook import tqdm +from tqdm import tqdm +from functools import partial + +from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like + + +class PLMSSampler(object): + def __init__(self, model, schedule="linear", **kwargs): + super().__init__() + self.model = model + self.ddpm_num_timesteps = model.num_timesteps + self.schedule = schedule + + def register_buffer(self, name, attr): + if type(attr) == torch.Tensor: + if attr.device != torch.device("cuda"): + attr = attr.to(torch.device("cuda")) + setattr(self, name, attr) + + def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): + if ddim_eta != 0: + raise ValueError('ddim_eta must be 0 for PLMS') + self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, + num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose) + alphas_cumprod = self.model.alphas_cumprod + assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' + to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) + + self.register_buffer('betas', to_torch(self.model.betas)) + self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) + self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))) + self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu()))) + self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu()))) + self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) + self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1))) + + # ddim sampling parameters + ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(), + ddim_timesteps=self.ddim_timesteps, + eta=ddim_eta,verbose=verbose) + self.register_buffer('ddim_sigmas', ddim_sigmas) + self.register_buffer('ddim_alphas', ddim_alphas) + self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) + self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas)) + sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( + (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( + 1 - self.alphas_cumprod / self.alphas_cumprod_prev)) + self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) + + @torch.no_grad() + def sample(self, + S, + batch_size, + shape, + conditioning=None, + callback=None, + normals_sequence=None, + img_callback=None, + quantize_x0=False, + eta=0., + mask=None, + x0=None, + temperature=1., + noise_dropout=0., + score_corrector=None, + corrector_kwargs=None, + verbose=True, + x_T=None, + log_every_t=100, + unconditional_guidance_scale=1., + unconditional_conditioning=None, + # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... + **kwargs + ): + if conditioning is not None: + if isinstance(conditioning, dict): + cbs = conditioning[list(conditioning.keys())[0]].shape[0] + if cbs != batch_size: + print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") + else: + if conditioning.shape[0] != batch_size: + print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") + + self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) + # sampling + C, H, W = shape + size = (batch_size, C, H, W) + print(f'Data shape for PLMS sampling is {size}') + + samples, intermediates = self.plms_sampling(conditioning, size, + callback=callback, + img_callback=img_callback, + quantize_denoised=quantize_x0, + mask=mask, x0=x0, + ddim_use_original_steps=False, + noise_dropout=noise_dropout, + temperature=temperature, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + x_T=x_T, + log_every_t=log_every_t, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + ) + return samples, intermediates + + @torch.no_grad() + def plms_sampling(self, cond, shape, + x_T=None, ddim_use_original_steps=False, + callback=None, timesteps=None, quantize_denoised=False, + mask=None, x0=None, img_callback=None, log_every_t=100, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, + unconditional_guidance_scale=1., unconditional_conditioning=None,): + device = self.model.betas.device + b = shape[0] + if x_T is None: + img = torch.randn(shape, device=device) + else: + img = x_T + + if timesteps is None: + timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps + elif timesteps is not None and not ddim_use_original_steps: + subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 + timesteps = self.ddim_timesteps[:subset_end] + + intermediates = {'x_inter': [img], 'pred_x0': [img]} + time_range = list(reversed(range(0,timesteps))) if ddim_use_original_steps else np.flip(timesteps) + total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] + print(f"Running PLMS Sampling with {total_steps} timesteps") + + iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps) + old_eps = [] + + for i, step in enumerate(iterator): + index = total_steps - i - 1 + ts = torch.full((b,), step, device=device, dtype=torch.long) + ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long) + + if mask is not None: + assert x0 is not None + img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? + img = img_orig * mask + (1. - mask) * img + + outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps, + quantize_denoised=quantize_denoised, temperature=temperature, + noise_dropout=noise_dropout, score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + old_eps=old_eps, t_next=ts_next) + img, pred_x0, e_t = outs + old_eps.append(e_t) + if len(old_eps) >= 4: + old_eps.pop(0) + if callback: callback(i) + if img_callback: img_callback(img, pred_x0, i) + + if index % log_every_t == 0 or index == total_steps - 1: + intermediates['x_inter'].append(img) + intermediates['pred_x0'].append(pred_x0) + + return img, intermediates + + @torch.no_grad() + def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, + unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None): + b, *_, device = *x.shape, x.device + + def get_model_output(x, t): + if unconditional_conditioning is None or unconditional_guidance_scale == 1.: + e_t = self.model.apply_model(x, t, c) + else: + x_in = torch.cat([x] * 2) + t_in = torch.cat([t] * 2) + c_in = torch.cat([unconditional_conditioning, c]) + e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) + e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) + + if score_corrector is not None: + assert self.model.parameterization == "eps" + e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) + + return e_t + + alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas + alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev + sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas + sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas + + def get_x_prev_and_pred_x0(e_t, index): + # select parameters corresponding to the currently considered timestep + a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) + a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) + sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) + sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device) + + # current prediction for x_0 + pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() + if quantize_denoised: + pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) + # direction pointing to x_t + dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t + noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature + if noise_dropout > 0.: + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise + return x_prev, pred_x0 + + e_t = get_model_output(x, t) + if len(old_eps) == 0: + # Pseudo Improved Euler (2nd order) + x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index) + e_t_next = get_model_output(x_prev, t_next) + e_t_prime = (e_t + e_t_next) / 2 + elif len(old_eps) == 1: + # 2nd order Pseudo Linear Multistep (Adams-Bashforth) + e_t_prime = (3 * e_t - old_eps[-1]) / 2 + elif len(old_eps) == 2: + # 3nd order Pseudo Linear Multistep (Adams-Bashforth) + e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12 + elif len(old_eps) >= 3: + # 4nd order Pseudo Linear Multistep (Adams-Bashforth) + e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24 + + x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index) + + return x_prev, pred_x0, e_t diff --git a/deforum-stable-diffusion/src/ldm/modules/attention.py b/deforum-stable-diffusion/src/ldm/modules/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..87125ff140a930d890abf8e7689feed65bd07a01 --- /dev/null +++ b/deforum-stable-diffusion/src/ldm/modules/attention.py @@ -0,0 +1,291 @@ +import gc +from inspect import isfunction +import math +import torch +import torch.nn.functional as F +from torch import nn, einsum +from einops import rearrange, repeat + +from ldm.modules.diffusionmodules.util import checkpoint + + +def exists(val): + return val is not None + + +def uniq(arr): + return{el: True for el in arr}.keys() + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +def max_neg_value(t): + return -torch.finfo(t.dtype).max + + +def init_(tensor): + dim = tensor.shape[-1] + std = 1 / math.sqrt(dim) + tensor.uniform_(-std, std) + return tensor + + +# feedforward +class GEGLU(nn.Module): + def __init__(self, dim_in, dim_out): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2) + + def forward(self, x): + x, gate = self.proj(x).chunk(2, dim=-1) + return x * F.gelu(gate) + + +class FeedForward(nn.Module): + def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): + super().__init__() + inner_dim = int(dim * mult) + dim_out = default(dim_out, dim) + project_in = nn.Sequential( + nn.Linear(dim, inner_dim), + nn.GELU() + ) if not glu else GEGLU(dim, inner_dim) + + self.net = nn.Sequential( + project_in, + nn.Dropout(dropout), + nn.Linear(inner_dim, dim_out) + ) + + def forward(self, x): + return self.net(x) + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def Normalize(in_channels): + return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + + +class LinearAttention(nn.Module): + def __init__(self, dim, heads=4, dim_head=32): + super().__init__() + self.heads = heads + hidden_dim = dim_head * heads + self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False) + self.to_out = nn.Conv2d(hidden_dim, dim, 1) + + def forward(self, x): + b, c, h, w = x.shape + qkv = self.to_qkv(x) + q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3) + k = k.softmax(dim=-1) + context = torch.einsum('bhdn,bhen->bhde', k, v) + out = torch.einsum('bhde,bhdn->bhen', context, q) + out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w) + return self.to_out(out) + + +class SpatialSelfAttention(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.k = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.v = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b,c,h,w = q.shape + q = rearrange(q, 'b c h w -> b (h w) c') + k = rearrange(k, 'b c h w -> b c (h w)') + w_ = torch.einsum('bij,bjk->bik', q, k) + + w_ = w_ * (int(c)**(-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = rearrange(v, 'b c h w -> b c (h w)') + w_ = rearrange(w_, 'b i j -> b j i') + h_ = torch.einsum('bij,bjk->bik', v, w_) + h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h) + h_ = self.proj_out(h_) + + return x+h_ + + +class CrossAttention(nn.Module): + def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.): + super().__init__() + inner_dim = dim_head * heads + context_dim = default(context_dim, query_dim) + + self.scale = dim_head ** -0.5 + self.heads = heads + + self.to_q = nn.Linear(query_dim, inner_dim, bias=False) + self.to_k = nn.Linear(context_dim, inner_dim, bias=False) + self.to_v = nn.Linear(context_dim, inner_dim, bias=False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, query_dim), + nn.Dropout(dropout) + ) + + def forward(self, x, context=None, mask=None): + h = self.heads + + q_in = self.to_q(x) + context = default(context, x) + k_in = self.to_k(context) + v_in = self.to_v(context) + del context, x + + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in)) + del q_in, k_in, v_in + + r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device) + + stats = torch.cuda.memory_stats(q.device) + mem_active = stats['active_bytes.all.current'] + mem_reserved = stats['reserved_bytes.all.current'] + mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device()) + mem_free_torch = mem_reserved - mem_active + mem_free_total = mem_free_cuda + mem_free_torch + + gb = 1024 ** 3 + tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * 4 + mem_required = tensor_size * 2.5 + steps = 1 + + if mem_required > mem_free_total: + steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2))) + # print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB " + # f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}") + + if steps > 64: + max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64 + raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). ' + f'Need: {mem_required/64/gb:0.1f}GB free, Have:{mem_free_total/gb:0.1f}GB free') + + slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1] + for i in range(0, q.shape[1], slice_size): + end = i + slice_size + s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * self.scale + + s2 = s1.softmax(dim=-1) + del s1 + + r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v) + del s2 + + del q, k, v + + r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h) + del r1 + + return self.to_out(r2) + + +class BasicTransformerBlock(nn.Module): + def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True): + super().__init__() + self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout) # is a self-attention + self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) + self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim, + heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + self.norm3 = nn.LayerNorm(dim) + self.checkpoint = checkpoint + + def forward(self, x, context=None): + return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint) + + def _forward(self, x, context=None): + x = self.attn1(self.norm1(x)) + x + x = self.attn2(self.norm2(x), context=context) + x + x = self.ff(self.norm3(x)) + x + return x + + +class SpatialTransformer(nn.Module): + """ + Transformer block for image-like data. + First, project the input (aka embedding) + and reshape to b, t, d. + Then apply standard transformer action. + Finally, reshape to image + """ + def __init__(self, in_channels, n_heads, d_head, + depth=1, dropout=0., context_dim=None): + super().__init__() + self.in_channels = in_channels + inner_dim = n_heads * d_head + self.norm = Normalize(in_channels) + + self.proj_in = nn.Conv2d(in_channels, + inner_dim, + kernel_size=1, + stride=1, + padding=0) + + self.transformer_blocks = nn.ModuleList( + [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim) + for d in range(depth)] + ) + + self.proj_out = zero_module(nn.Conv2d(inner_dim, + in_channels, + kernel_size=1, + stride=1, + padding=0)) + + def forward(self, x, context=None): + # note: if no context is given, cross-attention defaults to self-attention + b, c, h, w = x.shape + x_in = x + x = self.norm(x) + x = self.proj_in(x) + x = rearrange(x, 'b c h w -> b (h w) c') + for block in self.transformer_blocks: + x = block(x, context=context) + x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w) + x = self.proj_out(x) + return x + x_in diff --git a/deforum-stable-diffusion/src/ldm/modules/attention_xformers.py b/deforum-stable-diffusion/src/ldm/modules/attention_xformers.py new file mode 100644 index 0000000000000000000000000000000000000000..9cf6ce25cabf2eaa55951760ce6226c90270c0c5 --- /dev/null +++ b/deforum-stable-diffusion/src/ldm/modules/attention_xformers.py @@ -0,0 +1,420 @@ +import gc +from inspect import isfunction +import math +import torch +import torch.nn.functional as F +from torch import nn, einsum +from einops import rearrange, repeat +from typing import Any, Optional +import xformers +import xformers.ops + + +from ldm.modules.diffusionmodules.util import checkpoint + + +def exists(val): + return val is not None + + +def uniq(arr): + return{el: True for el in arr}.keys() + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +def max_neg_value(t): + return -torch.finfo(t.dtype).max + + +def init_(tensor): + dim = tensor.shape[-1] + std = 1 / math.sqrt(dim) + tensor.uniform_(-std, std) + return tensor + + + +# feedforward +class GEGLU(nn.Module): + def __init__(self, dim_in, dim_out): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2) + + def forward(self, x): + x, gate = self.proj(x).chunk(2, dim=-1) + return x * F.gelu(gate) + + +class FeedForward(nn.Module): + def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): + super().__init__() + inner_dim = int(dim * mult) + dim_out = default(dim_out, dim) + project_in = nn.Sequential( + nn.Linear(dim, inner_dim), + nn.GELU() + ) if not glu else GEGLU(dim, inner_dim) + + self.net = nn.Sequential( + project_in, + nn.Dropout(dropout), + nn.Linear(inner_dim, dim_out) + ) + + def forward(self, x): + return self.net(x) + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def Normalize(in_channels): + return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + + +class LinearAttention(nn.Module): + def __init__(self, dim, heads=4, dim_head=32): + super().__init__() + self.heads = heads + hidden_dim = dim_head * heads + self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False) + self.to_out = nn.Conv2d(hidden_dim, dim, 1) + + def forward(self, x): + b, c, h, w = x.shape + qkv = self.to_qkv(x) + q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3) + k = k.softmax(dim=-1) + context = torch.einsum('bhdn,bhen->bhde', k, v) + out = torch.einsum('bhde,bhdn->bhen', context, q) + out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w) + return self.to_out(out) + + +class SpatialSelfAttention(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.k = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.v = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b,c,h,w = q.shape + q = rearrange(q, 'b c h w -> b (h w) c') + k = rearrange(k, 'b c h w -> b c (h w)') + w_ = torch.einsum('bij,bjk->bik', q, k) + + w_ = w_ * (int(c)**(-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = rearrange(v, 'b c h w -> b c (h w)') + w_ = rearrange(w_, 'b i j -> b j i') + h_ = torch.einsum('bij,bjk->bik', v, w_) + h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h) + h_ = self.proj_out(h_) + + return x+h_ + + +class CrossAttention(nn.Module): + def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.): + super().__init__() + inner_dim = dim_head * heads + context_dim = default(context_dim, query_dim) + + self.scale = dim_head ** -0.5 + self.heads = heads + + self.to_q = nn.Linear(query_dim, inner_dim, bias=False) + self.to_k = nn.Linear(context_dim, inner_dim, bias=False) + self.to_v = nn.Linear(context_dim, inner_dim, bias=False) + self.to_out = nn.Sequential( + nn.Linear(inner_dim, query_dim), + nn.Dropout(dropout) + ) + + def forward(self, x, context=None, mask=None): + h = self.heads + + q_in = self.to_q(x) + context = default(context, x) + k_in = self.to_k(context) + v_in = self.to_v(context) + del context, x + + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in)) + del q_in, k_in, v_in + + r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device) + + stats = torch.cuda.memory_stats(q.device) + mem_active = stats['active_bytes.all.current'] + mem_reserved = stats['reserved_bytes.all.current'] + mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device()) + mem_free_torch = mem_reserved - mem_active + mem_free_total = mem_free_cuda + mem_free_torch + + gb = 1024 ** 3 + tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() + modifier = 3 if q.element_size() == 2 else 2.5 + mem_required = tensor_size * modifier + steps = 1 + + + if mem_required > mem_free_total: + steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2))) + # print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB " + # f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}") + + if steps > 64: + max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64 + raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). ' + f'Need: {mem_required/64/gb:0.1f}GB free, Have:{mem_free_total/gb:0.1f}GB free') + + slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1] + for i in range(0, q.shape[1], slice_size): + end = i + slice_size + s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * self.scale + + s2 = s1.softmax(dim=-1, dtype=q.dtype) + del s1 + + r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v) + del s2 + + del q, k, v + + r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h) + del r1 + + return self.to_out(r2) + + +class BasicTransformerBlock(nn.Module): + def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True): + super().__init__() + AttentionBuilder = MemoryEfficientCrossAttention + self.attn1 = AttentionBuilder(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout) # is a self-attention + self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) + self.attn2 = AttentionBuilder(query_dim=dim, context_dim=context_dim, + heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + self.norm3 = nn.LayerNorm(dim) + self.checkpoint = checkpoint + + def _set_attention_slice(self, slice_size): + self.attn1._slice_size = slice_size + self.attn2._slice_size = slice_size + + def forward(self, hidden_states, context=None): + hidden_states = hidden_states.contiguous() if hidden_states.device.type == "mps" else hidden_states + hidden_states = self.attn1(self.norm1(hidden_states)) + hidden_states + hidden_states = self.attn2(self.norm2(hidden_states), context=context) + hidden_states + hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states + return hidden_states + + # def forward(self, x, context=None): + # return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint) + + # def _forward(self, x, context=None): + # x = self.attn1(self.norm1(x)) + x + # x = self.attn2(self.norm2(x), context=context) + x + # x = self.ff(self.norm3(x)) + x + # return x + +class MemoryEfficientCrossAttention(nn.Module): + def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0): + super().__init__() + inner_dim = dim_head * heads + context_dim = default(context_dim, query_dim) + + self.scale = dim_head**-0.5 + self.heads = heads + self.dim_head = dim_head + + self.to_q = nn.Linear(query_dim, inner_dim, bias=False) + self.to_k = nn.Linear(context_dim, inner_dim, bias=False) + self.to_v = nn.Linear(context_dim, inner_dim, bias=False) + + self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)) + self.attention_op: Optional[Any] = None + + def _maybe_init(self, x): + """ + Initialize the attention operator, if required We expect the head dimension to be exposed here, meaning that x + : B, Head, Length + """ + if self.attention_op is not None: + return + + _, M, K = x.shape + try: + self.attention_op = xformers.ops.AttentionOpDispatch( + dtype=x.dtype, + device=x.device, + k=K, + attn_bias_type=type(None), + has_dropout=False, + kv_len=M, + q_len=M, + ).op + + except NotImplementedError as err: + raise NotImplementedError(f"Please install xformers with the flash attention / cutlass components.\n{err}") + + def forward(self, x, context=None, mask=None): + + + q = self.to_q(x) + context = default(context, x) + k = self.to_k(context) + v = self.to_v(context) + + + + b, _, _ = q.shape + q, k, v = map( + lambda t: t.unsqueeze(3) + .reshape(b, t.shape[1], self.heads, self.dim_head) + .permute(0, 2, 1, 3) + .reshape(b * self.heads, t.shape[1], self.dim_head) + .contiguous(), + (q, k, v), + ) + + # init the attention op, if required, using the proper dimensions + self._maybe_init(q) + + # actually compute the attention, what we cannot get enough of + out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op) + + # TODO: Use this directly in the attention operation, as a bias + if exists(mask): + raise NotImplementedError + out = ( + out.unsqueeze(0) + .reshape(b, self.heads, out.shape[1], self.dim_head) + .permute(0, 2, 1, 3) + .reshape(b, out.shape[1], self.heads * self.dim_head) + ) + + stats = torch.cuda.memory_stats(q.device) + mem_active = stats['active_bytes.all.current'] + mem_reserved = stats['reserved_bytes.all.current'] + mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device()) + mem_free_torch = mem_reserved - mem_active + mem_free_total = mem_free_cuda + mem_free_torch + + gb = 1024 ** 3 + tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() + modifier = 3 if q.element_size() == 2 else 2.5 + mem_required = tensor_size * modifier + steps = 1 + + + if mem_required > mem_free_total: + steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2))) + # print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB " + # f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}") + + if steps > 64: + max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64 + raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). ' + f'Need: {mem_required/64/gb:0.1f}GB free, Have:{mem_free_total/gb:0.1f}GB free') + + + + + + return self.to_out(out) + + + + + + + +class SpatialTransformer(nn.Module): + """ + Transformer block for image-like data. + First, project the input (aka embedding) + and reshape to b, t, d. + Then apply standard transformer action. + Finally, reshape to image + """ + def __init__(self, in_channels, n_heads, d_head, + depth=1, dropout=0., context_dim=None): + super().__init__() + self.in_channels = in_channels + inner_dim = n_heads * d_head + self.norm = Normalize(in_channels) + + self.proj_in = nn.Conv2d(in_channels, + inner_dim, + kernel_size=1, + stride=1, + padding=0) + + self.transformer_blocks = nn.ModuleList( + [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim) + for d in range(depth)] + ) + + self.proj_out = zero_module(nn.Conv2d(inner_dim, + in_channels, + kernel_size=1, + stride=1, + padding=0)) + + def forward(self, x, context=None): + # note: if no context is given, cross-attention defaults to self-attention + b, c, h, w = x.shape + x_in = x + x = self.norm(x) + x = self.proj_in(x) + x = rearrange(x, 'b c h w -> b (h w) c') + for block in self.transformer_blocks: + x = block(x, context=context) + x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w) + x = self.proj_out(x) + return x + x_in diff --git a/deforum-stable-diffusion/src/ldm/modules/diffusionmodules/__init__.py b/deforum-stable-diffusion/src/ldm/modules/diffusionmodules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/deforum-stable-diffusion/src/ldm/modules/diffusionmodules/__pycache__/__init__.cpython-39.pyc b/deforum-stable-diffusion/src/ldm/modules/diffusionmodules/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..413a566171be39f4bed92a66c393938d6ad26b2c Binary files /dev/null and b/deforum-stable-diffusion/src/ldm/modules/diffusionmodules/__pycache__/__init__.cpython-39.pyc differ diff --git a/deforum-stable-diffusion/src/ldm/modules/diffusionmodules/__pycache__/util.cpython-39.pyc b/deforum-stable-diffusion/src/ldm/modules/diffusionmodules/__pycache__/util.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..370205174dc2517e41bcf1eb5f19a0a027711281 Binary files /dev/null and b/deforum-stable-diffusion/src/ldm/modules/diffusionmodules/__pycache__/util.cpython-39.pyc differ diff --git a/deforum-stable-diffusion/src/ldm/modules/diffusionmodules/model.py b/deforum-stable-diffusion/src/ldm/modules/diffusionmodules/model.py new file mode 100644 index 0000000000000000000000000000000000000000..e55d7b8ce9569b3fc01cb009f6a577a58d206385 --- /dev/null +++ b/deforum-stable-diffusion/src/ldm/modules/diffusionmodules/model.py @@ -0,0 +1,907 @@ +# pytorch_diffusion + derived encoder decoder +import gc +import math +import torch +import torch.nn as nn +import numpy as np +from einops import rearrange + +from ldm.util import instantiate_from_config +from ldm.modules.attention import LinearAttention + + +def get_timestep_embedding(timesteps, embedding_dim): + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: + From Fairseq. + Build sinusoidal embeddings. + This matches the implementation in tensor2tensor, but differs slightly + from the description in Section 3.5 of "Attention Is All You Need". + """ + assert len(timesteps.shape) == 1 + + half_dim = embedding_dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) + emb = emb.to(device=timesteps.device) + emb = timesteps.float()[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0,1,0,0)) + return emb + + +def nonlinearity(x): + # swish + return x*torch.sigmoid(x) + + +def Normalize(in_channels, num_groups=32): + return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) + + +class Upsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + self.conv = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + if self.with_conv: + x = self.conv(x) + return x + + +class Downsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=3, + stride=2, + padding=0) + + def forward(self, x): + if self.with_conv: + pad = (0,1,0,1) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + else: + x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) + return x + + +class ResnetBlock(nn.Module): + def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, + dropout, temb_channels=512): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize(in_channels) + self.conv1 = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + if temb_channels > 0: + self.temb_proj = torch.nn.Linear(temb_channels, + out_channels) + self.norm2 = Normalize(out_channels) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = torch.nn.Conv2d(out_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + else: + self.nin_shortcut = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, x, temb): + h1 = x + h2 = self.norm1(h1) + del h1 + + h3 = nonlinearity(h2) + del h2 + + h4 = self.conv1(h3) + del h3 + + if temb is not None: + h4 = h4 + self.temb_proj(nonlinearity(temb))[:,:,None,None] + + h5 = self.norm2(h4) + del h4 + + h6 = nonlinearity(h5) + del h5 + + h7 = self.dropout(h6) + del h6 + + h8 = self.conv2(h7) + del h7 + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + + return x + h8 + + +class LinAttnBlock(LinearAttention): + """to match AttnBlock usage""" + def __init__(self, in_channels): + super().__init__(dim=in_channels, heads=1, dim_head=in_channels) + + +class AttnBlock(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.k = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.v = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q1 = self.q(h_) + k1 = self.k(h_) + v = self.v(h_) + + # compute attention + b, c, h, w = q1.shape + + q2 = q1.reshape(b, c, h*w) + del q1 + + q = q2.permute(0, 2, 1) # b,hw,c + del q2 + + k = k1.reshape(b, c, h*w) # b,c,hw + del k1 + + h_ = torch.zeros_like(k, device=q.device) + + stats = torch.cuda.memory_stats(q.device) + mem_active = stats['active_bytes.all.current'] + mem_reserved = stats['reserved_bytes.all.current'] + mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device()) + mem_free_torch = mem_reserved - mem_active + mem_free_total = mem_free_cuda + mem_free_torch + + tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * 4 + mem_required = tensor_size * 2.5 + steps = 1 + + if mem_required > mem_free_total: + steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2))) + + slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1] + for i in range(0, q.shape[1], slice_size): + end = i + slice_size + + w1 = torch.bmm(q[:, i:end], k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w2 = w1 * (int(c)**(-0.5)) + del w1 + w3 = torch.nn.functional.softmax(w2, dim=2) + del w2 + + # attend to values + v1 = v.reshape(b, c, h*w) + w4 = w3.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) + del w3 + + h_[:, :, i:end] = torch.bmm(v1, w4) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + del v1, w4 + + h2 = h_.reshape(b, c, h, w) + del h_ + + h3 = self.proj_out(h2) + del h2 + + h3 += x + + return h3 + + +def make_attn(in_channels, attn_type="vanilla"): + assert attn_type in ["vanilla", "linear", "none"], f'attn_type {attn_type} unknown' + #print(f"making attention of type '{attn_type}' with {in_channels} in_channels") + if attn_type == "vanilla": + return AttnBlock(in_channels) + elif attn_type == "none": + return nn.Identity(in_channels) + else: + return LinAttnBlock(in_channels) + + +class Model(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, + resolution, use_timestep=True, use_linear_attn=False, attn_type="vanilla"): + super().__init__() + if use_linear_attn: attn_type = "linear" + self.ch = ch + self.temb_ch = self.ch*4 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + self.use_timestep = use_timestep + if self.use_timestep: + # timestep embedding + self.temb = nn.Module() + self.temb.dense = nn.ModuleList([ + torch.nn.Linear(self.ch, + self.temb_ch), + torch.nn.Linear(self.temb_ch, + self.temb_ch), + ]) + + # downsampling + self.conv_in = torch.nn.Conv2d(in_channels, + self.ch, + kernel_size=3, + stride=1, + padding=1) + + curr_res = resolution + in_ch_mult = (1,)+tuple(ch_mult) + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch*in_ch_mult[i_level] + block_out = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions-1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch*ch_mult[i_level] + skip_in = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks+1): + if i_block == self.num_res_blocks: + skip_in = ch*in_ch_mult[i_level] + block.append(ResnetBlock(in_channels=block_in+skip_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + out_ch, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x, t=None, context=None): + #assert x.shape[2] == x.shape[3] == self.resolution + if context is not None: + # assume aligned context, cat along channel axis + x = torch.cat((x, context), dim=1) + if self.use_timestep: + # timestep embedding + assert t is not None + temb = get_timestep_embedding(t, self.ch) + temb = self.temb.dense[0](temb) + temb = nonlinearity(temb) + temb = self.temb.dense[1](temb) + else: + temb = None + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions-1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks+1): + h = self.up[i_level].block[i_block]( + torch.cat([h, hs.pop()], dim=1), temb) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + def get_last_layer(self): + return self.conv_out.weight + + +class Encoder(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, + resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla", + **ignore_kwargs): + super().__init__() + if use_linear_attn: attn_type = "linear" + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + # downsampling + self.conv_in = torch.nn.Conv2d(in_channels, + self.ch, + kernel_size=3, + stride=1, + padding=1) + + curr_res = resolution + in_ch_mult = (1,)+tuple(ch_mult) + self.in_ch_mult = in_ch_mult + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch*in_ch_mult[i_level] + block_out = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions-1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + 2*z_channels if double_z else z_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + # timestep embedding + temb = None + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions-1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class Decoder(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, + resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False, + attn_type="vanilla", **ignorekwargs): + super().__init__() + if use_linear_attn: attn_type = "linear" + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.give_pre_end = give_pre_end + self.tanh_out = tanh_out + + # compute in_ch_mult, block_in and curr_res at lowest res + in_ch_mult = (1,)+tuple(ch_mult) + block_in = ch*ch_mult[self.num_resolutions-1] + curr_res = resolution // 2**(self.num_resolutions-1) + self.z_shape = (1,z_channels,curr_res,curr_res) + #print("Working with z of shape {} = {} dimensions.".format( + # self.z_shape, np.prod(self.z_shape))) + + # z to block_in + self.conv_in = torch.nn.Conv2d(z_channels, + block_in, + kernel_size=3, + stride=1, + padding=1) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks+1): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + out_ch, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, z): + #assert z.shape[1:] == self.z_shape[1:] + self.last_z_shape = z.shape + + # timestep embedding + temb = None + + # z to block_in + h1 = self.conv_in(z) + + # middle + h2 = self.mid.block_1(h1, temb) + del h1 + + h3 = self.mid.attn_1(h2) + del h2 + + h = self.mid.block_2(h3, temb) + del h3 + + # prepare for up sampling + gc.collect() + torch.cuda.empty_cache() + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks+1): + h = self.up[i_level].block[i_block](h, temb) + if len(self.up[i_level].attn) > 0: + t = h + h = self.up[i_level].attn[i_block](t) + del t + + if i_level != 0: + t = h + h = self.up[i_level].upsample(t) + del t + + # end + if self.give_pre_end: + return h + + h1 = self.norm_out(h) + del h + + h2 = nonlinearity(h1) + del h1 + + h = self.conv_out(h2) + del h2 + + if self.tanh_out: + t = h + h = torch.tanh(t) + del t + + return h + + +class SimpleDecoder(nn.Module): + def __init__(self, in_channels, out_channels, *args, **kwargs): + super().__init__() + self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1), + ResnetBlock(in_channels=in_channels, + out_channels=2 * in_channels, + temb_channels=0, dropout=0.0), + ResnetBlock(in_channels=2 * in_channels, + out_channels=4 * in_channels, + temb_channels=0, dropout=0.0), + ResnetBlock(in_channels=4 * in_channels, + out_channels=2 * in_channels, + temb_channels=0, dropout=0.0), + nn.Conv2d(2*in_channels, in_channels, 1), + Upsample(in_channels, with_conv=True)]) + # end + self.norm_out = Normalize(in_channels) + self.conv_out = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + for i, layer in enumerate(self.model): + if i in [1,2,3]: + x = layer(x, None) + else: + x = layer(x) + + h = self.norm_out(x) + h = nonlinearity(h) + x = self.conv_out(h) + return x + + +class UpsampleDecoder(nn.Module): + def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution, + ch_mult=(2,2), dropout=0.0): + super().__init__() + # upsampling + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + block_in = in_channels + curr_res = resolution // 2 ** (self.num_resolutions - 1) + self.res_blocks = nn.ModuleList() + self.upsample_blocks = nn.ModuleList() + for i_level in range(self.num_resolutions): + res_block = [] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + res_block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + self.res_blocks.append(nn.ModuleList(res_block)) + if i_level != self.num_resolutions - 1: + self.upsample_blocks.append(Upsample(block_in, True)) + curr_res = curr_res * 2 + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + out_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + # upsampling + h = x + for k, i_level in enumerate(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.res_blocks[i_level][i_block](h, None) + if i_level != self.num_resolutions - 1: + h = self.upsample_blocks[k](h) + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class LatentRescaler(nn.Module): + def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2): + super().__init__() + # residual block, interpolate, residual block + self.factor = factor + self.conv_in = nn.Conv2d(in_channels, + mid_channels, + kernel_size=3, + stride=1, + padding=1) + self.res_block1 = nn.ModuleList([ResnetBlock(in_channels=mid_channels, + out_channels=mid_channels, + temb_channels=0, + dropout=0.0) for _ in range(depth)]) + self.attn = AttnBlock(mid_channels) + self.res_block2 = nn.ModuleList([ResnetBlock(in_channels=mid_channels, + out_channels=mid_channels, + temb_channels=0, + dropout=0.0) for _ in range(depth)]) + + self.conv_out = nn.Conv2d(mid_channels, + out_channels, + kernel_size=1, + ) + + def forward(self, x): + x = self.conv_in(x) + for block in self.res_block1: + x = block(x, None) + x = torch.nn.functional.interpolate(x, size=(int(round(x.shape[2]*self.factor)), int(round(x.shape[3]*self.factor)))) + x = self.attn(x) + for block in self.res_block2: + x = block(x, None) + x = self.conv_out(x) + return x + + +class MergedRescaleEncoder(nn.Module): + def __init__(self, in_channels, ch, resolution, out_ch, num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, + ch_mult=(1,2,4,8), rescale_factor=1.0, rescale_module_depth=1): + super().__init__() + intermediate_chn = ch * ch_mult[-1] + self.encoder = Encoder(in_channels=in_channels, num_res_blocks=num_res_blocks, ch=ch, ch_mult=ch_mult, + z_channels=intermediate_chn, double_z=False, resolution=resolution, + attn_resolutions=attn_resolutions, dropout=dropout, resamp_with_conv=resamp_with_conv, + out_ch=None) + self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=intermediate_chn, + mid_channels=intermediate_chn, out_channels=out_ch, depth=rescale_module_depth) + + def forward(self, x): + x = self.encoder(x) + x = self.rescaler(x) + return x + + +class MergedRescaleDecoder(nn.Module): + def __init__(self, z_channels, out_ch, resolution, num_res_blocks, attn_resolutions, ch, ch_mult=(1,2,4,8), + dropout=0.0, resamp_with_conv=True, rescale_factor=1.0, rescale_module_depth=1): + super().__init__() + tmp_chn = z_channels*ch_mult[-1] + self.decoder = Decoder(out_ch=out_ch, z_channels=tmp_chn, attn_resolutions=attn_resolutions, dropout=dropout, + resamp_with_conv=resamp_with_conv, in_channels=None, num_res_blocks=num_res_blocks, + ch_mult=ch_mult, resolution=resolution, ch=ch) + self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=z_channels, mid_channels=tmp_chn, + out_channels=tmp_chn, depth=rescale_module_depth) + + def forward(self, x): + x = self.rescaler(x) + x = self.decoder(x) + return x + + +class Upsampler(nn.Module): + def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2): + super().__init__() + assert out_size >= in_size + num_blocks = int(np.log2(out_size//in_size))+1 + factor_up = 1.+ (out_size % in_size) + print(f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}") + self.rescaler = LatentRescaler(factor=factor_up, in_channels=in_channels, mid_channels=2*in_channels, + out_channels=in_channels) + self.decoder = Decoder(out_ch=out_channels, resolution=out_size, z_channels=in_channels, num_res_blocks=2, + attn_resolutions=[], in_channels=None, ch=in_channels, + ch_mult=[ch_mult for _ in range(num_blocks)]) + + def forward(self, x): + x = self.rescaler(x) + x = self.decoder(x) + return x + + +class Resize(nn.Module): + def __init__(self, in_channels=None, learned=False, mode="bilinear"): + super().__init__() + self.with_conv = learned + self.mode = mode + if self.with_conv: + print(f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode") + raise NotImplementedError() + assert in_channels is not None + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=4, + stride=2, + padding=1) + + def forward(self, x, scale_factor=1.0): + if scale_factor==1.0: + return x + else: + x = torch.nn.functional.interpolate(x, mode=self.mode, align_corners=False, scale_factor=scale_factor) + return x + +class FirstStagePostProcessor(nn.Module): + + def __init__(self, ch_mult:list, in_channels, + pretrained_model:nn.Module=None, + reshape=False, + n_channels=None, + dropout=0., + pretrained_config=None): + super().__init__() + if pretrained_config is None: + assert pretrained_model is not None, 'Either "pretrained_model" or "pretrained_config" must not be None' + self.pretrained_model = pretrained_model + else: + assert pretrained_config is not None, 'Either "pretrained_model" or "pretrained_config" must not be None' + self.instantiate_pretrained(pretrained_config) + + self.do_reshape = reshape + + if n_channels is None: + n_channels = self.pretrained_model.encoder.ch + + self.proj_norm = Normalize(in_channels,num_groups=in_channels//2) + self.proj = nn.Conv2d(in_channels,n_channels,kernel_size=3, + stride=1,padding=1) + + blocks = [] + downs = [] + ch_in = n_channels + for m in ch_mult: + blocks.append(ResnetBlock(in_channels=ch_in,out_channels=m*n_channels,dropout=dropout)) + ch_in = m * n_channels + downs.append(Downsample(ch_in, with_conv=False)) + + self.model = nn.ModuleList(blocks) + self.downsampler = nn.ModuleList(downs) + + + def instantiate_pretrained(self, config): + model = instantiate_from_config(config) + self.pretrained_model = model.eval() + # self.pretrained_model.train = False + for param in self.pretrained_model.parameters(): + param.requires_grad = False + + + @torch.no_grad() + def encode_with_pretrained(self,x): + c = self.pretrained_model.encode(x) + if isinstance(c, DiagonalGaussianDistribution): + c = c.mode() + return c + + def forward(self,x): + z_fs = self.encode_with_pretrained(x) + z = self.proj_norm(z_fs) + z = self.proj(z) + z = nonlinearity(z) + + for submodel, downmodel in zip(self.model,self.downsampler): + z = submodel(z,temb=None) + z = downmodel(z) + + if self.do_reshape: + z = rearrange(z,'b c h w -> b (h w) c') + return z + diff --git a/deforum-stable-diffusion/src/ldm/modules/diffusionmodules/openaimodel.py b/deforum-stable-diffusion/src/ldm/modules/diffusionmodules/openaimodel.py new file mode 100644 index 0000000000000000000000000000000000000000..fcf95d1ea8a078dd259915109203789f78f0643a --- /dev/null +++ b/deforum-stable-diffusion/src/ldm/modules/diffusionmodules/openaimodel.py @@ -0,0 +1,961 @@ +from abc import abstractmethod +from functools import partial +import math +from typing import Iterable + +import numpy as np +import torch as th +import torch.nn as nn +import torch.nn.functional as F + +from ldm.modules.diffusionmodules.util import ( + checkpoint, + conv_nd, + linear, + avg_pool_nd, + zero_module, + normalization, + timestep_embedding, +) +from ldm.modules.attention import SpatialTransformer + + +# dummy replace +def convert_module_to_f16(x): + pass + +def convert_module_to_f32(x): + pass + + +## go +class AttentionPool2d(nn.Module): + """ + Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py + """ + + def __init__( + self, + spacial_dim: int, + embed_dim: int, + num_heads_channels: int, + output_dim: int = None, + ): + super().__init__() + self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5) + self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1) + self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1) + self.num_heads = embed_dim // num_heads_channels + self.attention = QKVAttention(self.num_heads) + + def forward(self, x): + b, c, *_spatial = x.shape + x = x.reshape(b, c, -1) # NC(HW) + x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1) + x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1) + x = self.qkv_proj(x) + x = self.attention(x) + x = self.c_proj(x) + return x[:, :, 0] + + +class TimestepBlock(nn.Module): + """ + Any module where forward() takes timestep embeddings as a second argument. + """ + + @abstractmethod + def forward(self, x, emb): + """ + Apply the module to `x` given `emb` timestep embeddings. + """ + + +class TimestepEmbedSequential(nn.Sequential, TimestepBlock): + """ + A sequential module that passes timestep embeddings to the children that + support it as an extra input. + """ + + def forward(self, x, emb, context=None): + for layer in self: + if isinstance(layer, TimestepBlock): + x = layer(x, emb) + elif isinstance(layer, SpatialTransformer): + x = layer(x, context) + else: + x = layer(x) + return x + + +class Upsample(nn.Module): + """ + An upsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + upsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + if use_conv: + self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding) + + def forward(self, x): + assert x.shape[1] == self.channels + if self.dims == 3: + x = F.interpolate( + x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest" + ) + else: + x = F.interpolate(x, scale_factor=2, mode="nearest") + if self.use_conv: + x = self.conv(x) + return x + +class TransposedUpsample(nn.Module): + 'Learned 2x upsampling without padding' + def __init__(self, channels, out_channels=None, ks=5): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + + self.up = nn.ConvTranspose2d(self.channels,self.out_channels,kernel_size=ks,stride=2) + + def forward(self,x): + return self.up(x) + + +class Downsample(nn.Module): + """ + A downsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + downsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + stride = 2 if dims != 3 else (1, 2, 2) + if use_conv: + self.op = conv_nd( + dims, self.channels, self.out_channels, 3, stride=stride, padding=padding + ) + else: + assert self.channels == self.out_channels + self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) + + def forward(self, x): + assert x.shape[1] == self.channels + return self.op(x) + + +class ResBlock(TimestepBlock): + """ + A residual block that can optionally change the number of channels. + :param channels: the number of input channels. + :param emb_channels: the number of timestep embedding channels. + :param dropout: the rate of dropout. + :param out_channels: if specified, the number of out channels. + :param use_conv: if True and out_channels is specified, use a spatial + convolution instead of a smaller 1x1 convolution to change the + channels in the skip connection. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param use_checkpoint: if True, use gradient checkpointing on this module. + :param up: if True, use this block for upsampling. + :param down: if True, use this block for downsampling. + """ + + def __init__( + self, + channels, + emb_channels, + dropout, + out_channels=None, + use_conv=False, + use_scale_shift_norm=False, + dims=2, + use_checkpoint=False, + up=False, + down=False, + ): + super().__init__() + self.channels = channels + self.emb_channels = emb_channels + self.dropout = dropout + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_checkpoint = use_checkpoint + self.use_scale_shift_norm = use_scale_shift_norm + + self.in_layers = nn.Sequential( + normalization(channels), + nn.SiLU(), + conv_nd(dims, channels, self.out_channels, 3, padding=1), + ) + + self.updown = up or down + + if up: + self.h_upd = Upsample(channels, False, dims) + self.x_upd = Upsample(channels, False, dims) + elif down: + self.h_upd = Downsample(channels, False, dims) + self.x_upd = Downsample(channels, False, dims) + else: + self.h_upd = self.x_upd = nn.Identity() + + self.emb_layers = nn.Sequential( + nn.SiLU(), + linear( + emb_channels, + 2 * self.out_channels if use_scale_shift_norm else self.out_channels, + ), + ) + self.out_layers = nn.Sequential( + normalization(self.out_channels), + nn.SiLU(), + nn.Dropout(p=dropout), + zero_module( + conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1) + ), + ) + + if self.out_channels == channels: + self.skip_connection = nn.Identity() + elif use_conv: + self.skip_connection = conv_nd( + dims, channels, self.out_channels, 3, padding=1 + ) + else: + self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) + + def forward(self, x, emb): + """ + Apply the block to a Tensor, conditioned on a timestep embedding. + :param x: an [N x C x ...] Tensor of features. + :param emb: an [N x emb_channels] Tensor of timestep embeddings. + :return: an [N x C x ...] Tensor of outputs. + """ + return checkpoint( + self._forward, (x, emb), self.parameters(), self.use_checkpoint + ) + + + def _forward(self, x, emb): + if self.updown: + in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] + h = in_rest(x) + h = self.h_upd(h) + x = self.x_upd(x) + h = in_conv(h) + else: + h = self.in_layers(x) + emb_out = self.emb_layers(emb).type(h.dtype) + while len(emb_out.shape) < len(h.shape): + emb_out = emb_out[..., None] + if self.use_scale_shift_norm: + out_norm, out_rest = self.out_layers[0], self.out_layers[1:] + scale, shift = th.chunk(emb_out, 2, dim=1) + h = out_norm(h) * (1 + scale) + shift + h = out_rest(h) + else: + h = h + emb_out + h = self.out_layers(h) + return self.skip_connection(x) + h + + +class AttentionBlock(nn.Module): + """ + An attention block that allows spatial positions to attend to each other. + Originally ported from here, but adapted to the N-d case. + https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. + """ + + def __init__( + self, + channels, + num_heads=1, + num_head_channels=-1, + use_checkpoint=False, + use_new_attention_order=False, + ): + super().__init__() + self.channels = channels + if num_head_channels == -1: + self.num_heads = num_heads + else: + assert ( + channels % num_head_channels == 0 + ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" + self.num_heads = channels // num_head_channels + self.use_checkpoint = use_checkpoint + self.norm = normalization(channels) + self.qkv = conv_nd(1, channels, channels * 3, 1) + if use_new_attention_order: + # split qkv before split heads + self.attention = QKVAttention(self.num_heads) + else: + # split heads before split qkv + self.attention = QKVAttentionLegacy(self.num_heads) + + self.proj_out = zero_module(conv_nd(1, channels, channels, 1)) + + def forward(self, x): + return checkpoint(self._forward, (x,), self.parameters(), True) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!! + #return pt_checkpoint(self._forward, x) # pytorch + + def _forward(self, x): + b, c, *spatial = x.shape + x = x.reshape(b, c, -1) + qkv = self.qkv(self.norm(x)) + h = self.attention(qkv) + h = self.proj_out(h) + return (x + h).reshape(b, c, *spatial) + + +def count_flops_attn(model, _x, y): + """ + A counter for the `thop` package to count the operations in an + attention operation. + Meant to be used like: + macs, params = thop.profile( + model, + inputs=(inputs, timestamps), + custom_ops={QKVAttention: QKVAttention.count_flops}, + ) + """ + b, c, *spatial = y[0].shape + num_spatial = int(np.prod(spatial)) + # We perform two matmuls with the same number of ops. + # The first computes the weight matrix, the second computes + # the combination of the value vectors. + matmul_ops = 2 * b * (num_spatial ** 2) * c + model.total_ops += th.DoubleTensor([matmul_ops]) + + +class QKVAttentionLegacy(nn.Module): + """ + A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping + """ + + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv): + """ + Apply QKV attention. + :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. + """ + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = th.einsum( + "bct,bcs->bts", q * scale, k * scale + ) # More stable with f16 than dividing afterwards + weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) + a = th.einsum("bts,bcs->bct", weight, v) + return a.reshape(bs, -1, length) + + @staticmethod + def count_flops(model, _x, y): + return count_flops_attn(model, _x, y) + + +class QKVAttention(nn.Module): + """ + A module which performs QKV attention and splits in a different order. + """ + + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv): + """ + Apply QKV attention. + :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. + """ + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = qkv.chunk(3, dim=1) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = th.einsum( + "bct,bcs->bts", + (q * scale).view(bs * self.n_heads, ch, length), + (k * scale).view(bs * self.n_heads, ch, length), + ) # More stable with f16 than dividing afterwards + weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) + a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length)) + return a.reshape(bs, -1, length) + + @staticmethod + def count_flops(model, _x, y): + return count_flops_attn(model, _x, y) + + +class UNetModel(nn.Module): + """ + The full UNet model with attention and timestep embedding. + :param in_channels: channels in the input Tensor. + :param model_channels: base channel count for the model. + :param out_channels: channels in the output Tensor. + :param num_res_blocks: number of residual blocks per downsample. + :param attention_resolutions: a collection of downsample rates at which + attention will take place. May be a set, list, or tuple. + For example, if this contains 4, then at 4x downsampling, attention + will be used. + :param dropout: the dropout probability. + :param channel_mult: channel multiplier for each level of the UNet. + :param conv_resample: if True, use learned convolutions for upsampling and + downsampling. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param num_classes: if specified (as an int), then this model will be + class-conditional with `num_classes` classes. + :param use_checkpoint: use gradient checkpointing to reduce memory usage. + :param num_heads: the number of attention heads in each attention layer. + :param num_heads_channels: if specified, ignore num_heads and instead use + a fixed channel width per attention head. + :param num_heads_upsample: works with num_heads to set a different number + of heads for upsampling. Deprecated. + :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. + :param resblock_updown: use residual blocks for up/downsampling. + :param use_new_attention_order: use a different attention pattern for potentially + increased efficiency. + """ + + def __init__( + self, + image_size, + in_channels, + model_channels, + out_channels, + num_res_blocks, + attention_resolutions, + dropout=0, + channel_mult=(1, 2, 4, 8), + conv_resample=True, + dims=2, + num_classes=None, + use_checkpoint=False, + use_fp16=False, + num_heads=-1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + resblock_updown=False, + use_new_attention_order=False, + use_spatial_transformer=False, # custom transformer support + transformer_depth=1, # custom transformer support + context_dim=None, # custom transformer support + n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model + legacy=True, + ): + super().__init__() + if use_spatial_transformer: + assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...' + + if context_dim is not None: + assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...' + from omegaconf.listconfig import ListConfig + if type(context_dim) == ListConfig: + context_dim = list(context_dim) + + if num_heads_upsample == -1: + num_heads_upsample = num_heads + + if num_heads == -1: + assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set' + + if num_head_channels == -1: + assert num_heads != -1, 'Either num_heads or num_head_channels has to be set' + + self.image_size = image_size + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + self.num_res_blocks = num_res_blocks + self.attention_resolutions = attention_resolutions + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.num_classes = num_classes + self.use_checkpoint = use_checkpoint + self.dtype = th.float16 if use_fp16 else th.float32 + self.num_heads = num_heads + self.num_head_channels = num_head_channels + self.num_heads_upsample = num_heads_upsample + self.predict_codebook_ids = n_embed is not None + + time_embed_dim = model_channels * 4 + self.time_embed = nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + + if self.num_classes is not None: + self.label_emb = nn.Embedding(num_classes, time_embed_dim) + + self.input_blocks = nn.ModuleList( + [ + TimestepEmbedSequential( + conv_nd(dims, in_channels, model_channels, 3, padding=1) + ) + ] + ) + self._feature_size = model_channels + input_block_chans = [model_channels] + ch = model_channels + ds = 1 + for level, mult in enumerate(channel_mult): + for _ in range(num_res_blocks): + layers = [ + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=mult * model_channels, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = mult * model_channels + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + #num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) if not use_spatial_transformer else SpatialTransformer( + ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim + ) + ) + self.input_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + input_block_chans.append(ch) + if level != len(channel_mult) - 1: + out_ch = ch + self.input_blocks.append( + TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + down=True, + ) + if resblock_updown + else Downsample( + ch, conv_resample, dims=dims, out_channels=out_ch + ) + ) + ) + ch = out_ch + input_block_chans.append(ch) + ds *= 2 + self._feature_size += ch + + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + #num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + self.middle_block = TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) if not use_spatial_transformer else SpatialTransformer( + ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim + ), + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + ) + self._feature_size += ch + + self.output_blocks = nn.ModuleList([]) + for level, mult in list(enumerate(channel_mult))[::-1]: + for i in range(num_res_blocks + 1): + ich = input_block_chans.pop() + layers = [ + ResBlock( + ch + ich, + time_embed_dim, + dropout, + out_channels=model_channels * mult, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = model_channels * mult + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + #num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads_upsample, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) if not use_spatial_transformer else SpatialTransformer( + ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim + ) + ) + if level and i == num_res_blocks: + out_ch = ch + layers.append( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + up=True, + ) + if resblock_updown + else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch) + ) + ds //= 2 + self.output_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)), + ) + if self.predict_codebook_ids: + self.id_predictor = nn.Sequential( + normalization(ch), + conv_nd(dims, model_channels, n_embed, 1), + #nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits + ) + + def convert_to_fp16(self): + """ + Convert the torso of the model to float16. + """ + self.input_blocks.apply(convert_module_to_f16) + self.middle_block.apply(convert_module_to_f16) + self.output_blocks.apply(convert_module_to_f16) + + def convert_to_fp32(self): + """ + Convert the torso of the model to float32. + """ + self.input_blocks.apply(convert_module_to_f32) + self.middle_block.apply(convert_module_to_f32) + self.output_blocks.apply(convert_module_to_f32) + + def forward(self, x, timesteps=None, context=None, y=None,**kwargs): + """ + Apply the model to an input batch. + :param x: an [N x C x ...] Tensor of inputs. + :param timesteps: a 1-D batch of timesteps. + :param context: conditioning plugged in via crossattn + :param y: an [N] Tensor of labels, if class-conditional. + :return: an [N x C x ...] Tensor of outputs. + """ + assert (y is not None) == ( + self.num_classes is not None + ), "must specify y if and only if the model is class-conditional" + hs = [] + t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) + emb = self.time_embed(t_emb) + + if self.num_classes is not None: + assert y.shape == (x.shape[0],) + emb = emb + self.label_emb(y) + + h = x.type(self.dtype) + for module in self.input_blocks: + h = module(h, emb, context) + hs.append(h) + h = self.middle_block(h, emb, context) + for module in self.output_blocks: + h = th.cat([h, hs.pop()], dim=1) + h = module(h, emb, context) + h = h.type(x.dtype) + if self.predict_codebook_ids: + return self.id_predictor(h) + else: + return self.out(h) + + +class EncoderUNetModel(nn.Module): + """ + The half UNet model with attention and timestep embedding. + For usage, see UNet. + """ + + def __init__( + self, + image_size, + in_channels, + model_channels, + out_channels, + num_res_blocks, + attention_resolutions, + dropout=0, + channel_mult=(1, 2, 4, 8), + conv_resample=True, + dims=2, + use_checkpoint=False, + use_fp16=False, + num_heads=1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + resblock_updown=False, + use_new_attention_order=False, + pool="adaptive", + *args, + **kwargs + ): + super().__init__() + + if num_heads_upsample == -1: + num_heads_upsample = num_heads + + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + self.num_res_blocks = num_res_blocks + self.attention_resolutions = attention_resolutions + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.use_checkpoint = use_checkpoint + self.dtype = th.float16 if use_fp16 else th.float32 + self.num_heads = num_heads + self.num_head_channels = num_head_channels + self.num_heads_upsample = num_heads_upsample + + time_embed_dim = model_channels * 4 + self.time_embed = nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + + self.input_blocks = nn.ModuleList( + [ + TimestepEmbedSequential( + conv_nd(dims, in_channels, model_channels, 3, padding=1) + ) + ] + ) + self._feature_size = model_channels + input_block_chans = [model_channels] + ch = model_channels + ds = 1 + for level, mult in enumerate(channel_mult): + for _ in range(num_res_blocks): + layers = [ + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=mult * model_channels, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = mult * model_channels + if ds in attention_resolutions: + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=num_head_channels, + use_new_attention_order=use_new_attention_order, + ) + ) + self.input_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + input_block_chans.append(ch) + if level != len(channel_mult) - 1: + out_ch = ch + self.input_blocks.append( + TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + down=True, + ) + if resblock_updown + else Downsample( + ch, conv_resample, dims=dims, out_channels=out_ch + ) + ) + ) + ch = out_ch + input_block_chans.append(ch) + ds *= 2 + self._feature_size += ch + + self.middle_block = TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=num_head_channels, + use_new_attention_order=use_new_attention_order, + ), + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + ) + self._feature_size += ch + self.pool = pool + if pool == "adaptive": + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + nn.AdaptiveAvgPool2d((1, 1)), + zero_module(conv_nd(dims, ch, out_channels, 1)), + nn.Flatten(), + ) + elif pool == "attention": + assert num_head_channels != -1 + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + AttentionPool2d( + (image_size // ds), ch, num_head_channels, out_channels + ), + ) + elif pool == "spatial": + self.out = nn.Sequential( + nn.Linear(self._feature_size, 2048), + nn.ReLU(), + nn.Linear(2048, self.out_channels), + ) + elif pool == "spatial_v2": + self.out = nn.Sequential( + nn.Linear(self._feature_size, 2048), + normalization(2048), + nn.SiLU(), + nn.Linear(2048, self.out_channels), + ) + else: + raise NotImplementedError(f"Unexpected {pool} pooling") + + def convert_to_fp16(self): + """ + Convert the torso of the model to float16. + """ + self.input_blocks.apply(convert_module_to_f16) + self.middle_block.apply(convert_module_to_f16) + + def convert_to_fp32(self): + """ + Convert the torso of the model to float32. + """ + self.input_blocks.apply(convert_module_to_f32) + self.middle_block.apply(convert_module_to_f32) + + def forward(self, x, timesteps): + """ + Apply the model to an input batch. + :param x: an [N x C x ...] Tensor of inputs. + :param timesteps: a 1-D batch of timesteps. + :return: an [N x K] Tensor of outputs. + """ + emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) + + results = [] + h = x.type(self.dtype) + for module in self.input_blocks: + h = module(h, emb) + if self.pool.startswith("spatial"): + results.append(h.type(x.dtype).mean(dim=(2, 3))) + h = self.middle_block(h, emb) + if self.pool.startswith("spatial"): + results.append(h.type(x.dtype).mean(dim=(2, 3))) + h = th.cat(results, axis=-1) + return self.out(h) + else: + h = h.type(x.dtype) + return self.out(h) + diff --git a/deforum-stable-diffusion/src/ldm/modules/diffusionmodules/util.py b/deforum-stable-diffusion/src/ldm/modules/diffusionmodules/util.py new file mode 100644 index 0000000000000000000000000000000000000000..17f9679a36480999990d2a157f2fa1934a001a33 --- /dev/null +++ b/deforum-stable-diffusion/src/ldm/modules/diffusionmodules/util.py @@ -0,0 +1,272 @@ +# adopted from +# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py +# and +# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py +# and +# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py +# +# thanks! + + +import os +import math +import torch +import torch.nn as nn +import numpy as np +from einops import repeat + +from ldm.util import instantiate_from_config + + +def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): + if schedule == "linear": + betas = ( + torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2 + ) + + elif schedule == "cosine": + timesteps = ( + torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s + ) + alphas = timesteps / (1 + cosine_s) * np.pi / 2 + alphas = torch.cos(alphas).pow(2) + alphas = alphas / alphas[0] + betas = 1 - alphas[1:] / alphas[:-1] + betas = np.clip(betas, a_min=0, a_max=0.999) + + elif schedule == "sqrt_linear": + betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) + elif schedule == "sqrt": + betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5 + else: + raise ValueError(f"schedule '{schedule}' unknown.") + return betas.numpy() + + +def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True): + if ddim_discr_method == 'uniform': + c = num_ddpm_timesteps // num_ddim_timesteps + ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) + elif ddim_discr_method == 'quad': + ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int) + elif ddim_discr_method == 'fill': + ddim_timesteps = np.linspace(0, num_ddpm_timesteps-1,num_ddim_timesteps+1).astype(int) + else: + raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"') + + # assert ddim_timesteps.shape[0] == num_ddim_timesteps + if ddim_discr_method == 'fill': + steps_out = ddim_timesteps + else: + # add one to get the final alpha values right (the ones from first scale to data during sampling) + steps_out = ddim_timesteps + 1 + if verbose: + print(f'Selected timesteps for ddim sampler: {steps_out}') + return steps_out + + +def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): + # select alphas for computing the variance schedule + alphas = alphacums[ddim_timesteps] + alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()) + + # according the the formula provided in https://arxiv.org/abs/2010.02502 + sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)) + if verbose: + print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}') + print(f'For the chosen value of eta, which is {eta}, ' + f'this results in the following sigma_t schedule for ddim sampler {sigmas}') + return sigmas, alphas, alphas_prev + + +def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, + which defines the cumulative product of (1-beta) over time from t = [0,1]. + :param num_diffusion_timesteps: the number of betas to produce. + :param alpha_bar: a lambda that takes an argument t from 0 to 1 and + produces the cumulative product of (1-beta) up to that + part of the diffusion process. + :param max_beta: the maximum beta to use; use values lower than 1 to + prevent singularities. + """ + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return np.array(betas) + + +def extract_into_tensor(a, t, x_shape): + b, *_ = t.shape + out = a.gather(-1, t) + return out.reshape(b, *((1,) * (len(x_shape) - 1))) + + +def checkpoint(func, inputs, params, flag): + """ + Evaluate a function without caching intermediate activations, allowing for + reduced memory at the expense of extra compute in the backward pass. + :param func: the function to evaluate. + :param inputs: the argument sequence to pass to `func`. + :param params: a sequence of parameters `func` depends on but does not + explicitly take as arguments. + :param flag: if False, disable gradient checkpointing. + """ + if flag: + args = tuple(inputs) + tuple(params) + return CheckpointFunction.apply(func, len(inputs), *args) + else: + return func(*inputs) + + +class CheckpointFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, run_function, length, *args): + ctx.run_function = run_function + ctx.input_tensors = list(args[:length]) + ctx.input_params = list(args[length:]) + + with torch.no_grad(): + output_tensors = ctx.run_function(*ctx.input_tensors) + return output_tensors + + @staticmethod + def backward(ctx, *output_grads): + ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] + with torch.enable_grad(): + # Fixes a bug where the first op in run_function modifies the + # Tensor storage in place, which is not allowed for detach()'d + # Tensors. + shallow_copies = [x.view_as(x) for x in ctx.input_tensors] + output_tensors = ctx.run_function(*shallow_copies) + input_grads = torch.autograd.grad( + output_tensors, + ctx.input_tensors + ctx.input_params, + output_grads, + allow_unused=True, + ) + del ctx.input_tensors + del ctx.input_params + del output_tensors + return (None, None) + input_grads + + +def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): + """ + Create sinusoidal timestep embeddings. + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. + """ + if not repeat_only: + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half + ).to(device=timesteps.device) + args = timesteps[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + else: + embedding = repeat(timesteps, 'b -> b d', d=dim) + return embedding + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def scale_module(module, scale): + """ + Scale the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().mul_(scale) + return module + + +def mean_flat(tensor): + """ + Take the mean over all non-batch dimensions. + """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +def normalization(channels): + """ + Make a standard normalization layer. + :param channels: number of input channels. + :return: an nn.Module for normalization. + """ + return GroupNorm32(32, channels) + + +# PyTorch 1.7 has SiLU, but we support PyTorch 1.5. +class SiLU(nn.Module): + def forward(self, x): + return x * torch.sigmoid(x) + + +class GroupNorm32(nn.GroupNorm): + def forward(self, x): + return super().forward(x.float()).type(x.dtype) + +def conv_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D convolution module. + """ + if dims == 1: + return nn.Conv1d(*args, **kwargs) + elif dims == 2: + return nn.Conv2d(*args, **kwargs) + elif dims == 3: + return nn.Conv3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def linear(*args, **kwargs): + """ + Create a linear module. + """ + return nn.Linear(*args, **kwargs) + + +def avg_pool_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D average pooling module. + """ + if dims == 1: + return nn.AvgPool1d(*args, **kwargs) + elif dims == 2: + return nn.AvgPool2d(*args, **kwargs) + elif dims == 3: + return nn.AvgPool3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +class HybridConditioner(nn.Module): + + def __init__(self, c_concat_config, c_crossattn_config): + super().__init__() + self.concat_conditioner = instantiate_from_config(c_concat_config) + self.crossattn_conditioner = instantiate_from_config(c_crossattn_config) + + def forward(self, c_concat, c_crossattn): + c_concat = self.concat_conditioner(c_concat) + c_crossattn = self.crossattn_conditioner(c_crossattn) + return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]} + + +def noise_like(shape, device, repeat=False): + repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) + noise = lambda: torch.randn(shape, device=device) + return repeat_noise() if repeat else noise() \ No newline at end of file diff --git a/deforum-stable-diffusion/src/ldm/modules/distributions/__init__.py b/deforum-stable-diffusion/src/ldm/modules/distributions/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/deforum-stable-diffusion/src/ldm/modules/distributions/distributions.py b/deforum-stable-diffusion/src/ldm/modules/distributions/distributions.py new file mode 100644 index 0000000000000000000000000000000000000000..f2b8ef901130efc171aa69742ca0244d94d3f2e9 --- /dev/null +++ b/deforum-stable-diffusion/src/ldm/modules/distributions/distributions.py @@ -0,0 +1,92 @@ +import torch +import numpy as np + + +class AbstractDistribution: + def sample(self): + raise NotImplementedError() + + def mode(self): + raise NotImplementedError() + + +class DiracDistribution(AbstractDistribution): + def __init__(self, value): + self.value = value + + def sample(self): + return self.value + + def mode(self): + return self.value + + +class DiagonalGaussianDistribution(object): + def __init__(self, parameters, deterministic=False): + self.parameters = parameters + self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) + self.logvar = torch.clamp(self.logvar, -30.0, 20.0) + self.deterministic = deterministic + self.std = torch.exp(0.5 * self.logvar) + self.var = torch.exp(self.logvar) + if self.deterministic: + self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) + + def sample(self): + x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) + return x + + def kl(self, other=None): + if self.deterministic: + return torch.Tensor([0.]) + else: + if other is None: + return 0.5 * torch.sum(torch.pow(self.mean, 2) + + self.var - 1.0 - self.logvar, + dim=[1, 2, 3]) + else: + return 0.5 * torch.sum( + torch.pow(self.mean - other.mean, 2) / other.var + + self.var / other.var - 1.0 - self.logvar + other.logvar, + dim=[1, 2, 3]) + + def nll(self, sample, dims=[1,2,3]): + if self.deterministic: + return torch.Tensor([0.]) + logtwopi = np.log(2.0 * np.pi) + return 0.5 * torch.sum( + logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, + dim=dims) + + def mode(self): + return self.mean + + +def normal_kl(mean1, logvar1, mean2, logvar2): + """ + source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 + Compute the KL divergence between two gaussians. + Shapes are automatically broadcasted, so batches can be compared to + scalars, among other use cases. + """ + tensor = None + for obj in (mean1, logvar1, mean2, logvar2): + if isinstance(obj, torch.Tensor): + tensor = obj + break + assert tensor is not None, "at least one argument must be a Tensor" + + # Force variances to be Tensors. Broadcasting helps convert scalars to + # Tensors, but it does not work for torch.exp(). + logvar1, logvar2 = [ + x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) + for x in (logvar1, logvar2) + ] + + return 0.5 * ( + -1.0 + + logvar2 + - logvar1 + + torch.exp(logvar1 - logvar2) + + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) + ) diff --git a/deforum-stable-diffusion/src/ldm/modules/ema.py b/deforum-stable-diffusion/src/ldm/modules/ema.py new file mode 100644 index 0000000000000000000000000000000000000000..c8c75af43565f6e140287644aaaefa97dd6e67c5 --- /dev/null +++ b/deforum-stable-diffusion/src/ldm/modules/ema.py @@ -0,0 +1,76 @@ +import torch +from torch import nn + + +class LitEma(nn.Module): + def __init__(self, model, decay=0.9999, use_num_upates=True): + super().__init__() + if decay < 0.0 or decay > 1.0: + raise ValueError('Decay must be between 0 and 1') + + self.m_name2s_name = {} + self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32)) + self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates + else torch.tensor(-1,dtype=torch.int)) + + for name, p in model.named_parameters(): + if p.requires_grad: + #remove as '.'-character is not allowed in buffers + s_name = name.replace('.','') + self.m_name2s_name.update({name:s_name}) + self.register_buffer(s_name,p.clone().detach().data) + + self.collected_params = [] + + def forward(self,model): + decay = self.decay + + if self.num_updates >= 0: + self.num_updates += 1 + decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates)) + + one_minus_decay = 1.0 - decay + + with torch.no_grad(): + m_param = dict(model.named_parameters()) + shadow_params = dict(self.named_buffers()) + + for key in m_param: + if m_param[key].requires_grad: + sname = self.m_name2s_name[key] + shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) + shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) + else: + assert not key in self.m_name2s_name + + def copy_to(self, model): + m_param = dict(model.named_parameters()) + shadow_params = dict(self.named_buffers()) + for key in m_param: + if m_param[key].requires_grad: + m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) + else: + assert not key in self.m_name2s_name + + def store(self, parameters): + """ + Save the current parameters for restoring later. + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + temporarily stored. + """ + self.collected_params = [param.clone() for param in parameters] + + def restore(self, parameters): + """ + Restore the parameters stored with the `store` method. + Useful to validate the model with EMA parameters without affecting the + original optimization process. Store the parameters before the + `copy_to` method. After validation (or model saving), use this to + restore the former parameters. + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + updated with the stored parameters. + """ + for c_param, param in zip(self.collected_params, parameters): + param.data.copy_(c_param.data) diff --git a/deforum-stable-diffusion/src/ldm/modules/embedding_manager.py b/deforum-stable-diffusion/src/ldm/modules/embedding_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..fb6ae420a9a57b3d8a47a85d4238cdf8eb8d42d5 --- /dev/null +++ b/deforum-stable-diffusion/src/ldm/modules/embedding_manager.py @@ -0,0 +1,255 @@ +from cmath import log +import torch +from torch import nn + +import sys + +from ldm.data.personalized import per_img_token_list +from transformers import CLIPTokenizer +from functools import partial + +DEFAULT_PLACEHOLDER_TOKEN = ['*'] + +PROGRESSIVE_SCALE = 2000 + + +def get_clip_token_for_string(tokenizer, string): + batch_encoding = tokenizer( + string, + truncation=True, + max_length=77, + return_length=True, + return_overflowing_tokens=False, + padding='max_length', + return_tensors='pt', + ) + tokens = batch_encoding['input_ids'] + """ assert ( + torch.count_nonzero(tokens - 49407) == 2 + ), f"String '{string}' maps to more than a single token. Please use another string" """ + + return tokens[0, 1] + + +def get_bert_token_for_string(tokenizer, string): + token = tokenizer(string) + # assert torch.count_nonzero(token) == 3, f"String '{string}' maps to more than a single token. Please use another string" + + token = token[0, 1] + + return token + + +def get_embedding_for_clip_token(embedder, token): + return embedder(token.unsqueeze(0))[0, 0] + + +class EmbeddingManager(nn.Module): + def __init__( + self, + embedder, + placeholder_strings=None, + initializer_words=None, + per_image_tokens=False, + num_vectors_per_token=1, + progressive_words=False, + **kwargs, + ): + super().__init__() + + self.embedder = embedder + + self.string_to_token_dict = {} + self.string_to_param_dict = nn.ParameterDict() + + self.initial_embeddings = ( + nn.ParameterDict() + ) # These should not be optimized + + self.progressive_words = progressive_words + self.progressive_counter = 0 + + self.max_vectors_per_token = num_vectors_per_token + + if hasattr( + embedder, 'tokenizer' + ): # using Stable Diffusion's CLIP encoder + self.is_clip = True + get_token_for_string = partial( + get_clip_token_for_string, embedder.tokenizer + ) + get_embedding_for_tkn = partial( + get_embedding_for_clip_token, + embedder.transformer.text_model.embeddings, + ) + token_dim = 1280 + else: # using LDM's BERT encoder + self.is_clip = False + get_token_for_string = partial( + get_bert_token_for_string, embedder.tknz_fn + ) + get_embedding_for_tkn = embedder.transformer.token_emb + token_dim = 1280 + + if per_image_tokens: + placeholder_strings.extend(per_img_token_list) + + for idx, placeholder_string in enumerate(placeholder_strings): + + token = get_token_for_string(placeholder_string) + + if initializer_words and idx < len(initializer_words): + init_word_token = get_token_for_string(initializer_words[idx]) + + with torch.no_grad(): + init_word_embedding = get_embedding_for_tkn( + init_word_token.cpu() + ) + + token_params = torch.nn.Parameter( + init_word_embedding.unsqueeze(0).repeat( + num_vectors_per_token, 1 + ), + requires_grad=True, + ) + self.initial_embeddings[ + placeholder_string + ] = torch.nn.Parameter( + init_word_embedding.unsqueeze(0).repeat( + num_vectors_per_token, 1 + ), + requires_grad=False, + ) + else: + token_params = torch.nn.Parameter( + torch.rand( + size=(num_vectors_per_token, token_dim), + requires_grad=True, + ) + ) + + self.string_to_token_dict[placeholder_string] = token + self.string_to_param_dict[placeholder_string] = token_params + + def forward( + self, + tokenized_text, + embedded_text, + ): + b, n, device = *tokenized_text.shape, tokenized_text.device + + for ( + placeholder_string, + placeholder_token, + ) in self.string_to_token_dict.items(): + + placeholder_embedding = self.string_to_param_dict[ + placeholder_string + ].to(device) + + if ( + self.max_vectors_per_token == 1 + ): # If there's only one vector per token, we can do a simple replacement + placeholder_idx = torch.where( + tokenized_text == placeholder_token.to(device) + ) + embedded_text[placeholder_idx] = placeholder_embedding + else: # otherwise, need to insert and keep track of changing indices + if self.progressive_words: + self.progressive_counter += 1 + max_step_tokens = ( + 1 + self.progressive_counter // PROGRESSIVE_SCALE + ) + else: + max_step_tokens = self.max_vectors_per_token + + num_vectors_for_token = min( + placeholder_embedding.shape[0], max_step_tokens + ) + + placeholder_rows, placeholder_cols = torch.where( + tokenized_text == placeholder_token.to(device) + ) + + if placeholder_rows.nelement() == 0: + continue + + sorted_cols, sort_idx = torch.sort( + placeholder_cols, descending=True + ) + sorted_rows = placeholder_rows[sort_idx] + + for idx in range(len(sorted_rows)): + row = sorted_rows[idx] + col = sorted_cols[idx] + + new_token_row = torch.cat( + [ + tokenized_text[row][:col], + placeholder_token.repeat(num_vectors_for_token).to( + device + ), + tokenized_text[row][col + 1 :], + ], + axis=0, + )[:n] + new_embed_row = torch.cat( + [ + embedded_text[row][:col], + placeholder_embedding[:num_vectors_for_token], + embedded_text[row][col + 1 :], + ], + axis=0, + )[:n] + + embedded_text[row] = new_embed_row + tokenized_text[row] = new_token_row + + return embedded_text + + def save(self, ckpt_path): + torch.save( + { + 'string_to_token': self.string_to_token_dict, + 'string_to_param': self.string_to_param_dict, + }, + ckpt_path, + ) + + def load(self, ckpt_path): + ckpt = torch.load(ckpt_path, map_location='cpu') + + self.string_to_token_dict = ckpt["string_to_token"] + self.string_to_param_dict = ckpt["string_to_param"] + + + def get_embedding_norms_squared(self): + all_params = torch.cat( + list(self.string_to_param_dict.values()), axis=0 + ) # num_placeholders x embedding_dim + param_norm_squared = (all_params * all_params).sum( + axis=-1 + ) # num_placeholders + + return param_norm_squared + + def embedding_parameters(self): + return self.string_to_param_dict.parameters() + + def embedding_to_coarse_loss(self): + + loss = 0.0 + num_embeddings = len(self.initial_embeddings) + + for key in self.initial_embeddings: + optimized = self.string_to_param_dict[key] + coarse = self.initial_embeddings[key].clone().to(optimized.device) + + loss = ( + loss + + (optimized - coarse) + @ (optimized - coarse).T + / num_embeddings + ) + + return loss diff --git a/deforum-stable-diffusion/src/ldm/modules/embedding_managerbin.py b/deforum-stable-diffusion/src/ldm/modules/embedding_managerbin.py new file mode 100644 index 0000000000000000000000000000000000000000..25df677444e7938d836594ba4b38bbde92c6e671 --- /dev/null +++ b/deforum-stable-diffusion/src/ldm/modules/embedding_managerbin.py @@ -0,0 +1,175 @@ +import torch +from torch import nn + +from ldm.data.personalized import per_img_token_list +from transformers import CLIPTokenizer +from functools import partial + +DEFAULT_PLACEHOLDER_TOKEN = ["*"] + +PROGRESSIVE_SCALE = 2000 + +def get_clip_token_for_string(tokenizer, string): + batch_encoding = tokenizer(string, truncation=True, max_length=77, return_length=True, + return_overflowing_tokens=False, padding="max_length", return_tensors="pt") + tokens = batch_encoding["input_ids"] + #assert torch.count_nonzero(tokens - 49407) == 2, f"String '{string}' maps to more than a single token. Please use another string" + + return tokens[0, 1] + +def get_bert_token_for_string(tokenizer, string): + token = tokenizer(string) + assert torch.count_nonzero(token) == 3, f"String '{string}' maps to more than a single token. Please use another string" + + token = token[0, 1] + + return token + +def get_embedding_for_clip_token(embedder, token): + return embedder(token.unsqueeze(0))[0, 0] + + +class EmbeddingManager(nn.Module): + def __init__( + self, + embedder, + placeholder_strings=None, + initializer_words=None, + per_image_tokens=False, + num_vectors_per_token=1, + progressive_words=False, + **kwargs + ): + super().__init__() + + self.string_to_token_dict = {} + + self.string_to_param_dict = nn.ParameterDict() + + self.initial_embeddings = nn.ParameterDict() # These should not be optimized + + self.progressive_words = progressive_words + self.progressive_counter = 0 + + self.max_vectors_per_token = num_vectors_per_token + + if hasattr(embedder, 'tokenizer'): # using Stable Diffusion's CLIP encoder + self.is_clip = True + get_token_for_string = partial(get_clip_token_for_string, embedder.tokenizer) + get_embedding_for_tkn = partial(get_embedding_for_clip_token, embedder.transformer.text_model.embeddings) + token_dim = 768 + else: # using LDM's BERT encoder + self.is_clip = False + get_token_for_string = partial(get_bert_token_for_string, embedder.tknz_fn) + get_embedding_for_tkn = embedder.transformer.token_emb + token_dim = 1280 + + if per_image_tokens: + placeholder_strings.extend(per_img_token_list) + + for idx, placeholder_string in enumerate(placeholder_strings): + + token = get_token_for_string(placeholder_string) + + if initializer_words and idx < len(initializer_words): + init_word_token = get_token_for_string(initializer_words[idx]) + + with torch.no_grad(): + init_word_embedding = get_embedding_for_tkn(init_word_token.cpu()) + + token_params = torch.nn.Parameter(init_word_embedding.unsqueeze(0).repeat(num_vectors_per_token, 1), requires_grad=True) + self.initial_embeddings[placeholder_string] = torch.nn.Parameter(init_word_embedding.unsqueeze(0).repeat(num_vectors_per_token, 1), requires_grad=False) + else: + token_params = torch.nn.Parameter(torch.rand(size=(num_vectors_per_token, token_dim), requires_grad=True)) + + self.string_to_token_dict[placeholder_string] = token + self.string_to_param_dict[placeholder_string] = token_params + + def forward( + self, + tokenized_text, + embedded_text, + ): + b, n, device = *tokenized_text.shape, tokenized_text.device + + for placeholder_string, placeholder_token in self.string_to_token_dict.items(): + + placeholder_embedding = self.string_to_param_dict[placeholder_string].to(device) + + if self.max_vectors_per_token == 1: # If there's only one vector per token, we can do a simple replacement + placeholder_idx = torch.where(tokenized_text == placeholder_token.to(device)) + embedded_text[placeholder_idx] = placeholder_embedding + else: # otherwise, need to insert and keep track of changing indices + if self.progressive_words: + self.progressive_counter += 1 + max_step_tokens = 1 + self.progressive_counter // PROGRESSIVE_SCALE + else: + max_step_tokens = self.max_vectors_per_token + + num_vectors_for_token = min(placeholder_embedding.shape[0], max_step_tokens) + + placeholder_rows, placeholder_cols = torch.where(tokenized_text == placeholder_token.to(device)) + + if placeholder_rows.nelement() == 0: + continue + + sorted_cols, sort_idx = torch.sort(placeholder_cols, descending=True) + sorted_rows = placeholder_rows[sort_idx] + + for idx in range(len(sorted_rows)): + row = sorted_rows[idx] + col = sorted_cols[idx] + + new_token_row = torch.cat([tokenized_text[row][:col], placeholder_token.repeat(num_vectors_for_token).to(device), tokenized_text[row][col + 1:]], axis=0)[:n] + new_embed_row = torch.cat([embedded_text[row][:col], placeholder_embedding[:num_vectors_for_token], embedded_text[row][col + 1:]], axis=0)[:n] + + embedded_text[row] = new_embed_row + tokenized_text[row] = new_token_row + + return embedded_text + + def save(self, ckpt_path): + torch.save({"string_to_token": self.string_to_token_dict, + "string_to_param": self.string_to_param_dict}, ckpt_path) + + def load(self, ckpt_path): + ckpt = torch.load(ckpt_path, map_location='cpu') + if isinstance(ckpt, nn.ParameterDict): + self.string_to_token_dict = ckpt["string_to_token"] + self.string_to_param_dict = ckpt["string_to_param"] + else: + file_token = list(ckpt.keys())[0] + new_token = '*' + + tensor_size = ckpt[file_token].count_nonzero() + newt = ckpt[file_token].reshape(1, tensor_size) + newt = newt.half() + + nparam = nn.Parameter(data = newt, requires_grad=True) + + self.string_to_token_dict = {new_token: torch.tensor(265)} + self.string_to_param_dict = nn.ParameterDict({new_token: nparam}) + + print(f'Added terms: {", ".join(self.string_to_param_dict.keys())}') + + def get_embedding_norms_squared(self): + all_params = torch.cat(list(self.string_to_param_dict.values()), axis=0) # num_placeholders x embedding_dim + param_norm_squared = (all_params * all_params).sum(axis=-1) # num_placeholders + + return param_norm_squared + + def embedding_parameters(self): + return self.string_to_param_dict.parameters() + + def embedding_to_coarse_loss(self): + + loss = 0. + num_embeddings = len(self.initial_embeddings) + + for key in self.initial_embeddings: + optimized = self.string_to_param_dict[key] + coarse = self.initial_embeddings[key].clone().to(optimized.device) + + loss = loss + (optimized - coarse) @ (optimized - coarse).T / num_embeddings + + return loss diff --git a/deforum-stable-diffusion/src/ldm/modules/embedding_managerpt.py b/deforum-stable-diffusion/src/ldm/modules/embedding_managerpt.py new file mode 100644 index 0000000000000000000000000000000000000000..fb6ae420a9a57b3d8a47a85d4238cdf8eb8d42d5 --- /dev/null +++ b/deforum-stable-diffusion/src/ldm/modules/embedding_managerpt.py @@ -0,0 +1,255 @@ +from cmath import log +import torch +from torch import nn + +import sys + +from ldm.data.personalized import per_img_token_list +from transformers import CLIPTokenizer +from functools import partial + +DEFAULT_PLACEHOLDER_TOKEN = ['*'] + +PROGRESSIVE_SCALE = 2000 + + +def get_clip_token_for_string(tokenizer, string): + batch_encoding = tokenizer( + string, + truncation=True, + max_length=77, + return_length=True, + return_overflowing_tokens=False, + padding='max_length', + return_tensors='pt', + ) + tokens = batch_encoding['input_ids'] + """ assert ( + torch.count_nonzero(tokens - 49407) == 2 + ), f"String '{string}' maps to more than a single token. Please use another string" """ + + return tokens[0, 1] + + +def get_bert_token_for_string(tokenizer, string): + token = tokenizer(string) + # assert torch.count_nonzero(token) == 3, f"String '{string}' maps to more than a single token. Please use another string" + + token = token[0, 1] + + return token + + +def get_embedding_for_clip_token(embedder, token): + return embedder(token.unsqueeze(0))[0, 0] + + +class EmbeddingManager(nn.Module): + def __init__( + self, + embedder, + placeholder_strings=None, + initializer_words=None, + per_image_tokens=False, + num_vectors_per_token=1, + progressive_words=False, + **kwargs, + ): + super().__init__() + + self.embedder = embedder + + self.string_to_token_dict = {} + self.string_to_param_dict = nn.ParameterDict() + + self.initial_embeddings = ( + nn.ParameterDict() + ) # These should not be optimized + + self.progressive_words = progressive_words + self.progressive_counter = 0 + + self.max_vectors_per_token = num_vectors_per_token + + if hasattr( + embedder, 'tokenizer' + ): # using Stable Diffusion's CLIP encoder + self.is_clip = True + get_token_for_string = partial( + get_clip_token_for_string, embedder.tokenizer + ) + get_embedding_for_tkn = partial( + get_embedding_for_clip_token, + embedder.transformer.text_model.embeddings, + ) + token_dim = 1280 + else: # using LDM's BERT encoder + self.is_clip = False + get_token_for_string = partial( + get_bert_token_for_string, embedder.tknz_fn + ) + get_embedding_for_tkn = embedder.transformer.token_emb + token_dim = 1280 + + if per_image_tokens: + placeholder_strings.extend(per_img_token_list) + + for idx, placeholder_string in enumerate(placeholder_strings): + + token = get_token_for_string(placeholder_string) + + if initializer_words and idx < len(initializer_words): + init_word_token = get_token_for_string(initializer_words[idx]) + + with torch.no_grad(): + init_word_embedding = get_embedding_for_tkn( + init_word_token.cpu() + ) + + token_params = torch.nn.Parameter( + init_word_embedding.unsqueeze(0).repeat( + num_vectors_per_token, 1 + ), + requires_grad=True, + ) + self.initial_embeddings[ + placeholder_string + ] = torch.nn.Parameter( + init_word_embedding.unsqueeze(0).repeat( + num_vectors_per_token, 1 + ), + requires_grad=False, + ) + else: + token_params = torch.nn.Parameter( + torch.rand( + size=(num_vectors_per_token, token_dim), + requires_grad=True, + ) + ) + + self.string_to_token_dict[placeholder_string] = token + self.string_to_param_dict[placeholder_string] = token_params + + def forward( + self, + tokenized_text, + embedded_text, + ): + b, n, device = *tokenized_text.shape, tokenized_text.device + + for ( + placeholder_string, + placeholder_token, + ) in self.string_to_token_dict.items(): + + placeholder_embedding = self.string_to_param_dict[ + placeholder_string + ].to(device) + + if ( + self.max_vectors_per_token == 1 + ): # If there's only one vector per token, we can do a simple replacement + placeholder_idx = torch.where( + tokenized_text == placeholder_token.to(device) + ) + embedded_text[placeholder_idx] = placeholder_embedding + else: # otherwise, need to insert and keep track of changing indices + if self.progressive_words: + self.progressive_counter += 1 + max_step_tokens = ( + 1 + self.progressive_counter // PROGRESSIVE_SCALE + ) + else: + max_step_tokens = self.max_vectors_per_token + + num_vectors_for_token = min( + placeholder_embedding.shape[0], max_step_tokens + ) + + placeholder_rows, placeholder_cols = torch.where( + tokenized_text == placeholder_token.to(device) + ) + + if placeholder_rows.nelement() == 0: + continue + + sorted_cols, sort_idx = torch.sort( + placeholder_cols, descending=True + ) + sorted_rows = placeholder_rows[sort_idx] + + for idx in range(len(sorted_rows)): + row = sorted_rows[idx] + col = sorted_cols[idx] + + new_token_row = torch.cat( + [ + tokenized_text[row][:col], + placeholder_token.repeat(num_vectors_for_token).to( + device + ), + tokenized_text[row][col + 1 :], + ], + axis=0, + )[:n] + new_embed_row = torch.cat( + [ + embedded_text[row][:col], + placeholder_embedding[:num_vectors_for_token], + embedded_text[row][col + 1 :], + ], + axis=0, + )[:n] + + embedded_text[row] = new_embed_row + tokenized_text[row] = new_token_row + + return embedded_text + + def save(self, ckpt_path): + torch.save( + { + 'string_to_token': self.string_to_token_dict, + 'string_to_param': self.string_to_param_dict, + }, + ckpt_path, + ) + + def load(self, ckpt_path): + ckpt = torch.load(ckpt_path, map_location='cpu') + + self.string_to_token_dict = ckpt["string_to_token"] + self.string_to_param_dict = ckpt["string_to_param"] + + + def get_embedding_norms_squared(self): + all_params = torch.cat( + list(self.string_to_param_dict.values()), axis=0 + ) # num_placeholders x embedding_dim + param_norm_squared = (all_params * all_params).sum( + axis=-1 + ) # num_placeholders + + return param_norm_squared + + def embedding_parameters(self): + return self.string_to_param_dict.parameters() + + def embedding_to_coarse_loss(self): + + loss = 0.0 + num_embeddings = len(self.initial_embeddings) + + for key in self.initial_embeddings: + optimized = self.string_to_param_dict[key] + coarse = self.initial_embeddings[key].clone().to(optimized.device) + + loss = ( + loss + + (optimized - coarse) + @ (optimized - coarse).T + / num_embeddings + ) + + return loss diff --git a/deforum-stable-diffusion/src/ldm/modules/encoders/__init__.py b/deforum-stable-diffusion/src/ldm/modules/encoders/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/deforum-stable-diffusion/src/ldm/modules/encoders/modules.py b/deforum-stable-diffusion/src/ldm/modules/encoders/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..9f179bee547545bed8c963342a61086dc6392372 --- /dev/null +++ b/deforum-stable-diffusion/src/ldm/modules/encoders/modules.py @@ -0,0 +1,234 @@ +import torch +import torch.nn as nn +from functools import partial +import clip +from einops import rearrange, repeat +from transformers import CLIPTokenizer, CLIPTextModel +import kornia + +from ldm.modules.x_transformer import Encoder, TransformerWrapper # TODO: can we directly rely on lucidrains code and simply add this as a reuirement? --> test + + +class AbstractEncoder(nn.Module): + def __init__(self): + super().__init__() + + def encode(self, *args, **kwargs): + raise NotImplementedError + + + +class ClassEmbedder(nn.Module): + def __init__(self, embed_dim, n_classes=1000, key='class'): + super().__init__() + self.key = key + self.embedding = nn.Embedding(n_classes, embed_dim) + + def forward(self, batch, key=None): + if key is None: + key = self.key + # this is for use in crossattn + c = batch[key][:, None] + c = self.embedding(c) + return c + + +class TransformerEmbedder(AbstractEncoder): + """Some transformer encoder layers""" + def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, device="cuda"): + super().__init__() + self.device = device + self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len, + attn_layers=Encoder(dim=n_embed, depth=n_layer)) + + def forward(self, tokens): + tokens = tokens.to(self.device) # meh + z = self.transformer(tokens, return_embeddings=True) + return z + + def encode(self, x): + return self(x) + + +class BERTTokenizer(AbstractEncoder): + """ Uses a pretrained BERT tokenizer by huggingface. Vocab size: 30522 (?)""" + def __init__(self, device="cuda", vq_interface=True, max_length=77): + super().__init__() + from transformers import BertTokenizerFast # TODO: add to reuquirements + self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased") + self.device = device + self.vq_interface = vq_interface + self.max_length = max_length + + def forward(self, text): + batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, + return_overflowing_tokens=False, padding="max_length", return_tensors="pt") + tokens = batch_encoding["input_ids"].to(self.device) + return tokens + + @torch.no_grad() + def encode(self, text): + tokens = self(text) + if not self.vq_interface: + return tokens + return None, None, [None, None, tokens] + + def decode(self, text): + return text + + +class BERTEmbedder(AbstractEncoder): + """Uses the BERT tokenizr model and add some transformer encoder layers""" + def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=77, + device="cuda",use_tokenizer=True, embedding_dropout=0.0): + super().__init__() + self.use_tknz_fn = use_tokenizer + if self.use_tknz_fn: + self.tknz_fn = BERTTokenizer(vq_interface=False, max_length=max_seq_len) + self.device = device + self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len, + attn_layers=Encoder(dim=n_embed, depth=n_layer), + emb_dropout=embedding_dropout) + + def forward(self, text): + if self.use_tknz_fn: + tokens = self.tknz_fn(text)#.to(self.device) + else: + tokens = text + z = self.transformer(tokens, return_embeddings=True) + return z + + def encode(self, text): + # output of length 77 + return self(text) + + +class SpatialRescaler(nn.Module): + def __init__(self, + n_stages=1, + method='bilinear', + multiplier=0.5, + in_channels=3, + out_channels=None, + bias=False): + super().__init__() + self.n_stages = n_stages + assert self.n_stages >= 0 + assert method in ['nearest','linear','bilinear','trilinear','bicubic','area'] + self.multiplier = multiplier + self.interpolator = partial(torch.nn.functional.interpolate, mode=method) + self.remap_output = out_channels is not None + if self.remap_output: + print(f'Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing.') + self.channel_mapper = nn.Conv2d(in_channels,out_channels,1,bias=bias) + + def forward(self,x): + for stage in range(self.n_stages): + x = self.interpolator(x, scale_factor=self.multiplier) + + + if self.remap_output: + x = self.channel_mapper(x) + return x + + def encode(self, x): + return self(x) + +class FrozenCLIPEmbedder(AbstractEncoder): + """Uses the CLIP transformer encoder for text (from Hugging Face)""" + def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77): + super().__init__() + self.tokenizer = CLIPTokenizer.from_pretrained(version) + self.transformer = CLIPTextModel.from_pretrained(version) + self.device = device + self.max_length = max_length + self.freeze() + + def freeze(self): + self.transformer = self.transformer.eval() + for param in self.parameters(): + param.requires_grad = False + + def forward(self, text): + batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, + return_overflowing_tokens=False, padding="max_length", return_tensors="pt") + tokens = batch_encoding["input_ids"].to(self.device) + outputs = self.transformer(input_ids=tokens) + + z = outputs.last_hidden_state + return z + + def encode(self, text): + return self(text) + + +class FrozenCLIPTextEmbedder(nn.Module): + """ + Uses the CLIP transformer encoder for text. + """ + def __init__(self, version='ViT-L/14', device="cuda", max_length=77, n_repeat=1, normalize=True): + super().__init__() + self.model, _ = clip.load(version, jit=False, device="cpu") + self.device = device + self.max_length = max_length + self.n_repeat = n_repeat + self.normalize = normalize + + def freeze(self): + self.model = self.model.eval() + for param in self.parameters(): + param.requires_grad = False + + def forward(self, text): + tokens = clip.tokenize(text).to(self.device) + z = self.model.encode_text(tokens) + if self.normalize: + z = z / torch.linalg.norm(z, dim=1, keepdim=True) + return z + + def encode(self, text): + z = self(text) + if z.ndim==2: + z = z[:, None, :] + z = repeat(z, 'b 1 d -> b k d', k=self.n_repeat) + return z + + +class FrozenClipImageEmbedder(nn.Module): + """ + Uses the CLIP image encoder. + """ + def __init__( + self, + model, + jit=False, + device='cuda' if torch.cuda.is_available() else 'cpu', + antialias=False, + ): + super().__init__() + self.model, _ = clip.load(name=model, device=device, jit=jit) + + self.antialias = antialias + + self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False) + self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False) + + def preprocess(self, x): + # normalize to [0,1] + x = kornia.geometry.resize(x, (224, 224), + interpolation='bicubic',align_corners=True, + antialias=self.antialias) + x = (x + 1.) / 2. + # renormalize according to clip + x = kornia.enhance.normalize(x, self.mean, self.std) + return x + + def forward(self, x): + # x is assumed to be in range [-1,1] + return self.model.encode_image(self.preprocess(x)) + + +if __name__ == "__main__": + from ldm.util import count_params + model = FrozenCLIPEmbedder() + count_params(model, verbose=False) diff --git a/deforum-stable-diffusion/src/ldm/modules/image_degradation/__init__.py b/deforum-stable-diffusion/src/ldm/modules/image_degradation/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7836cada81f90ded99c58d5942eea4c3477f58fc --- /dev/null +++ b/deforum-stable-diffusion/src/ldm/modules/image_degradation/__init__.py @@ -0,0 +1,2 @@ +from ldm.modules.image_degradation.bsrgan import degradation_bsrgan_variant as degradation_fn_bsr +from ldm.modules.image_degradation.bsrgan_light import degradation_bsrgan_variant as degradation_fn_bsr_light diff --git a/deforum-stable-diffusion/src/ldm/modules/image_degradation/bsrgan.py b/deforum-stable-diffusion/src/ldm/modules/image_degradation/bsrgan.py new file mode 100644 index 0000000000000000000000000000000000000000..32ef56169978e550090261cddbcf5eb611a6173b --- /dev/null +++ b/deforum-stable-diffusion/src/ldm/modules/image_degradation/bsrgan.py @@ -0,0 +1,730 @@ +# -*- coding: utf-8 -*- +""" +# -------------------------------------------- +# Super-Resolution +# -------------------------------------------- +# +# Kai Zhang (cskaizhang@gmail.com) +# https://github.com/cszn +# From 2019/03--2021/08 +# -------------------------------------------- +""" + +import numpy as np +import cv2 +import torch + +from functools import partial +import random +from scipy import ndimage +import scipy +import scipy.stats as ss +from scipy.interpolate import interp2d +from scipy.linalg import orth +import albumentations + +import ldm.modules.image_degradation.utils_image as util + + +def modcrop_np(img, sf): + ''' + Args: + img: numpy image, WxH or WxHxC + sf: scale factor + Return: + cropped image + ''' + w, h = img.shape[:2] + im = np.copy(img) + return im[:w - w % sf, :h - h % sf, ...] + + +""" +# -------------------------------------------- +# anisotropic Gaussian kernels +# -------------------------------------------- +""" + + +def analytic_kernel(k): + """Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)""" + k_size = k.shape[0] + # Calculate the big kernels size + big_k = np.zeros((3 * k_size - 2, 3 * k_size - 2)) + # Loop over the small kernel to fill the big one + for r in range(k_size): + for c in range(k_size): + big_k[2 * r:2 * r + k_size, 2 * c:2 * c + k_size] += k[r, c] * k + # Crop the edges of the big kernel to ignore very small values and increase run time of SR + crop = k_size // 2 + cropped_big_k = big_k[crop:-crop, crop:-crop] + # Normalize to 1 + return cropped_big_k / cropped_big_k.sum() + + +def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6): + """ generate an anisotropic Gaussian kernel + Args: + ksize : e.g., 15, kernel size + theta : [0, pi], rotation angle range + l1 : [0.1,50], scaling of eigenvalues + l2 : [0.1,l1], scaling of eigenvalues + If l1 = l2, will get an isotropic Gaussian kernel. + Returns: + k : kernel + """ + + v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1., 0.])) + V = np.array([[v[0], v[1]], [v[1], -v[0]]]) + D = np.array([[l1, 0], [0, l2]]) + Sigma = np.dot(np.dot(V, D), np.linalg.inv(V)) + k = gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize) + + return k + + +def gm_blur_kernel(mean, cov, size=15): + center = size / 2.0 + 0.5 + k = np.zeros([size, size]) + for y in range(size): + for x in range(size): + cy = y - center + 1 + cx = x - center + 1 + k[y, x] = ss.multivariate_normal.pdf([cx, cy], mean=mean, cov=cov) + + k = k / np.sum(k) + return k + + +def shift_pixel(x, sf, upper_left=True): + """shift pixel for super-resolution with different scale factors + Args: + x: WxHxC or WxH + sf: scale factor + upper_left: shift direction + """ + h, w = x.shape[:2] + shift = (sf - 1) * 0.5 + xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0) + if upper_left: + x1 = xv + shift + y1 = yv + shift + else: + x1 = xv - shift + y1 = yv - shift + + x1 = np.clip(x1, 0, w - 1) + y1 = np.clip(y1, 0, h - 1) + + if x.ndim == 2: + x = interp2d(xv, yv, x)(x1, y1) + if x.ndim == 3: + for i in range(x.shape[-1]): + x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1) + + return x + + +def blur(x, k): + ''' + x: image, NxcxHxW + k: kernel, Nx1xhxw + ''' + n, c = x.shape[:2] + p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2 + x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode='replicate') + k = k.repeat(1, c, 1, 1) + k = k.view(-1, 1, k.shape[2], k.shape[3]) + x = x.view(1, -1, x.shape[2], x.shape[3]) + x = torch.nn.functional.conv2d(x, k, bias=None, stride=1, padding=0, groups=n * c) + x = x.view(n, c, x.shape[2], x.shape[3]) + + return x + + +def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0): + """" + # modified version of https://github.com/assafshocher/BlindSR_dataset_generator + # Kai Zhang + # min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var + # max_var = 2.5 * sf + """ + # Set random eigen-vals (lambdas) and angle (theta) for COV matrix + lambda_1 = min_var + np.random.rand() * (max_var - min_var) + lambda_2 = min_var + np.random.rand() * (max_var - min_var) + theta = np.random.rand() * np.pi # random theta + noise = -noise_level + np.random.rand(*k_size) * noise_level * 2 + + # Set COV matrix using Lambdas and Theta + LAMBDA = np.diag([lambda_1, lambda_2]) + Q = np.array([[np.cos(theta), -np.sin(theta)], + [np.sin(theta), np.cos(theta)]]) + SIGMA = Q @ LAMBDA @ Q.T + INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :] + + # Set expectation position (shifting kernel for aligned image) + MU = k_size // 2 - 0.5 * (scale_factor - 1) # - 0.5 * (scale_factor - k_size % 2) + MU = MU[None, None, :, None] + + # Create meshgrid for Gaussian + [X, Y] = np.meshgrid(range(k_size[0]), range(k_size[1])) + Z = np.stack([X, Y], 2)[:, :, :, None] + + # Calcualte Gaussian for every pixel of the kernel + ZZ = Z - MU + ZZ_t = ZZ.transpose(0, 1, 3, 2) + raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise) + + # shift the kernel so it will be centered + # raw_kernel_centered = kernel_shift(raw_kernel, scale_factor) + + # Normalize the kernel and return + # kernel = raw_kernel_centered / np.sum(raw_kernel_centered) + kernel = raw_kernel / np.sum(raw_kernel) + return kernel + + +def fspecial_gaussian(hsize, sigma): + hsize = [hsize, hsize] + siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0] + std = sigma + [x, y] = np.meshgrid(np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1)) + arg = -(x * x + y * y) / (2 * std * std) + h = np.exp(arg) + h[h < scipy.finfo(float).eps * h.max()] = 0 + sumh = h.sum() + if sumh != 0: + h = h / sumh + return h + + +def fspecial_laplacian(alpha): + alpha = max([0, min([alpha, 1])]) + h1 = alpha / (alpha + 1) + h2 = (1 - alpha) / (alpha + 1) + h = [[h1, h2, h1], [h2, -4 / (alpha + 1), h2], [h1, h2, h1]] + h = np.array(h) + return h + + +def fspecial(filter_type, *args, **kwargs): + ''' + python code from: + https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py + ''' + if filter_type == 'gaussian': + return fspecial_gaussian(*args, **kwargs) + if filter_type == 'laplacian': + return fspecial_laplacian(*args, **kwargs) + + +""" +# -------------------------------------------- +# degradation models +# -------------------------------------------- +""" + + +def bicubic_degradation(x, sf=3): + ''' + Args: + x: HxWxC image, [0, 1] + sf: down-scale factor + Return: + bicubicly downsampled LR image + ''' + x = util.imresize_np(x, scale=1 / sf) + return x + + +def srmd_degradation(x, k, sf=3): + ''' blur + bicubic downsampling + Args: + x: HxWxC image, [0, 1] + k: hxw, double + sf: down-scale factor + Return: + downsampled LR image + Reference: + @inproceedings{zhang2018learning, + title={Learning a single convolutional super-resolution network for multiple degradations}, + author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei}, + booktitle={IEEE Conference on Computer Vision and Pattern Recognition}, + pages={3262--3271}, + year={2018} + } + ''' + x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') # 'nearest' | 'mirror' + x = bicubic_degradation(x, sf=sf) + return x + + +def dpsr_degradation(x, k, sf=3): + ''' bicubic downsampling + blur + Args: + x: HxWxC image, [0, 1] + k: hxw, double + sf: down-scale factor + Return: + downsampled LR image + Reference: + @inproceedings{zhang2019deep, + title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels}, + author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei}, + booktitle={IEEE Conference on Computer Vision and Pattern Recognition}, + pages={1671--1681}, + year={2019} + } + ''' + x = bicubic_degradation(x, sf=sf) + x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') + return x + + +def classical_degradation(x, k, sf=3): + ''' blur + downsampling + Args: + x: HxWxC image, [0, 1]/[0, 255] + k: hxw, double + sf: down-scale factor + Return: + downsampled LR image + ''' + x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') + # x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2)) + st = 0 + return x[st::sf, st::sf, ...] + + +def add_sharpening(img, weight=0.5, radius=50, threshold=10): + """USM sharpening. borrowed from real-ESRGAN + Input image: I; Blurry image: B. + 1. K = I + weight * (I - B) + 2. Mask = 1 if abs(I - B) > threshold, else: 0 + 3. Blur mask: + 4. Out = Mask * K + (1 - Mask) * I + Args: + img (Numpy array): Input image, HWC, BGR; float32, [0, 1]. + weight (float): Sharp weight. Default: 1. + radius (float): Kernel size of Gaussian blur. Default: 50. + threshold (int): + """ + if radius % 2 == 0: + radius += 1 + blur = cv2.GaussianBlur(img, (radius, radius), 0) + residual = img - blur + mask = np.abs(residual) * 255 > threshold + mask = mask.astype('float32') + soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0) + + K = img + weight * residual + K = np.clip(K, 0, 1) + return soft_mask * K + (1 - soft_mask) * img + + +def add_blur(img, sf=4): + wd2 = 4.0 + sf + wd = 2.0 + 0.2 * sf + if random.random() < 0.5: + l1 = wd2 * random.random() + l2 = wd2 * random.random() + k = anisotropic_Gaussian(ksize=2 * random.randint(2, 11) + 3, theta=random.random() * np.pi, l1=l1, l2=l2) + else: + k = fspecial('gaussian', 2 * random.randint(2, 11) + 3, wd * random.random()) + img = ndimage.filters.convolve(img, np.expand_dims(k, axis=2), mode='mirror') + + return img + + +def add_resize(img, sf=4): + rnum = np.random.rand() + if rnum > 0.8: # up + sf1 = random.uniform(1, 2) + elif rnum < 0.7: # down + sf1 = random.uniform(0.5 / sf, 1) + else: + sf1 = 1.0 + img = cv2.resize(img, (int(sf1 * img.shape[1]), int(sf1 * img.shape[0])), interpolation=random.choice([1, 2, 3])) + img = np.clip(img, 0.0, 1.0) + + return img + + +# def add_Gaussian_noise(img, noise_level1=2, noise_level2=25): +# noise_level = random.randint(noise_level1, noise_level2) +# rnum = np.random.rand() +# if rnum > 0.6: # add color Gaussian noise +# img += np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32) +# elif rnum < 0.4: # add grayscale Gaussian noise +# img += np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32) +# else: # add noise +# L = noise_level2 / 255. +# D = np.diag(np.random.rand(3)) +# U = orth(np.random.rand(3, 3)) +# conv = np.dot(np.dot(np.transpose(U), D), U) +# img += np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32) +# img = np.clip(img, 0.0, 1.0) +# return img + +def add_Gaussian_noise(img, noise_level1=2, noise_level2=25): + noise_level = random.randint(noise_level1, noise_level2) + rnum = np.random.rand() + if rnum > 0.6: # add color Gaussian noise + img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32) + elif rnum < 0.4: # add grayscale Gaussian noise + img = img + np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32) + else: # add noise + L = noise_level2 / 255. + D = np.diag(np.random.rand(3)) + U = orth(np.random.rand(3, 3)) + conv = np.dot(np.dot(np.transpose(U), D), U) + img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32) + img = np.clip(img, 0.0, 1.0) + return img + + +def add_speckle_noise(img, noise_level1=2, noise_level2=25): + noise_level = random.randint(noise_level1, noise_level2) + img = np.clip(img, 0.0, 1.0) + rnum = random.random() + if rnum > 0.6: + img += img * np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32) + elif rnum < 0.4: + img += img * np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32) + else: + L = noise_level2 / 255. + D = np.diag(np.random.rand(3)) + U = orth(np.random.rand(3, 3)) + conv = np.dot(np.dot(np.transpose(U), D), U) + img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32) + img = np.clip(img, 0.0, 1.0) + return img + + +def add_Poisson_noise(img): + img = np.clip((img * 255.0).round(), 0, 255) / 255. + vals = 10 ** (2 * random.random() + 2.0) # [2, 4] + if random.random() < 0.5: + img = np.random.poisson(img * vals).astype(np.float32) / vals + else: + img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114]) + img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255. + noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray + img += noise_gray[:, :, np.newaxis] + img = np.clip(img, 0.0, 1.0) + return img + + +def add_JPEG_noise(img): + quality_factor = random.randint(30, 95) + img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR) + result, encimg = cv2.imencode('.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor]) + img = cv2.imdecode(encimg, 1) + img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB) + return img + + +def random_crop(lq, hq, sf=4, lq_patchsize=64): + h, w = lq.shape[:2] + rnd_h = random.randint(0, h - lq_patchsize) + rnd_w = random.randint(0, w - lq_patchsize) + lq = lq[rnd_h:rnd_h + lq_patchsize, rnd_w:rnd_w + lq_patchsize, :] + + rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf) + hq = hq[rnd_h_H:rnd_h_H + lq_patchsize * sf, rnd_w_H:rnd_w_H + lq_patchsize * sf, :] + return lq, hq + + +def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None): + """ + This is the degradation model of BSRGAN from the paper + "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution" + ---------- + img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf) + sf: scale factor + isp_model: camera ISP model + Returns + ------- + img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1] + hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1] + """ + isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25 + sf_ori = sf + + h1, w1 = img.shape[:2] + img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop + h, w = img.shape[:2] + + if h < lq_patchsize * sf or w < lq_patchsize * sf: + raise ValueError(f'img size ({h1}X{w1}) is too small!') + + hq = img.copy() + + if sf == 4 and random.random() < scale2_prob: # downsample1 + if np.random.rand() < 0.5: + img = cv2.resize(img, (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])), + interpolation=random.choice([1, 2, 3])) + else: + img = util.imresize_np(img, 1 / 2, True) + img = np.clip(img, 0.0, 1.0) + sf = 2 + + shuffle_order = random.sample(range(7), 7) + idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3) + if idx1 > idx2: # keep downsample3 last + shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1] + + for i in shuffle_order: + + if i == 0: + img = add_blur(img, sf=sf) + + elif i == 1: + img = add_blur(img, sf=sf) + + elif i == 2: + a, b = img.shape[1], img.shape[0] + # downsample2 + if random.random() < 0.75: + sf1 = random.uniform(1, 2 * sf) + img = cv2.resize(img, (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])), + interpolation=random.choice([1, 2, 3])) + else: + k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf)) + k_shifted = shift_pixel(k, sf) + k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel + img = ndimage.filters.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror') + img = img[0::sf, 0::sf, ...] # nearest downsampling + img = np.clip(img, 0.0, 1.0) + + elif i == 3: + # downsample3 + img = cv2.resize(img, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3])) + img = np.clip(img, 0.0, 1.0) + + elif i == 4: + # add Gaussian noise + img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25) + + elif i == 5: + # add JPEG noise + if random.random() < jpeg_prob: + img = add_JPEG_noise(img) + + elif i == 6: + # add processed camera sensor noise + if random.random() < isp_prob and isp_model is not None: + with torch.no_grad(): + img, hq = isp_model.forward(img.copy(), hq) + + # add final JPEG compression noise + img = add_JPEG_noise(img) + + # random crop + img, hq = random_crop(img, hq, sf_ori, lq_patchsize) + + return img, hq + + +# todo no isp_model? +def degradation_bsrgan_variant(image, sf=4, isp_model=None): + """ + This is the degradation model of BSRGAN from the paper + "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution" + ---------- + sf: scale factor + isp_model: camera ISP model + Returns + ------- + img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1] + hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1] + """ + image = util.uint2single(image) + isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25 + sf_ori = sf + + h1, w1 = image.shape[:2] + image = image.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop + h, w = image.shape[:2] + + hq = image.copy() + + if sf == 4 and random.random() < scale2_prob: # downsample1 + if np.random.rand() < 0.5: + image = cv2.resize(image, (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])), + interpolation=random.choice([1, 2, 3])) + else: + image = util.imresize_np(image, 1 / 2, True) + image = np.clip(image, 0.0, 1.0) + sf = 2 + + shuffle_order = random.sample(range(7), 7) + idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3) + if idx1 > idx2: # keep downsample3 last + shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1] + + for i in shuffle_order: + + if i == 0: + image = add_blur(image, sf=sf) + + elif i == 1: + image = add_blur(image, sf=sf) + + elif i == 2: + a, b = image.shape[1], image.shape[0] + # downsample2 + if random.random() < 0.75: + sf1 = random.uniform(1, 2 * sf) + image = cv2.resize(image, (int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0])), + interpolation=random.choice([1, 2, 3])) + else: + k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf)) + k_shifted = shift_pixel(k, sf) + k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel + image = ndimage.filters.convolve(image, np.expand_dims(k_shifted, axis=2), mode='mirror') + image = image[0::sf, 0::sf, ...] # nearest downsampling + image = np.clip(image, 0.0, 1.0) + + elif i == 3: + # downsample3 + image = cv2.resize(image, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3])) + image = np.clip(image, 0.0, 1.0) + + elif i == 4: + # add Gaussian noise + image = add_Gaussian_noise(image, noise_level1=2, noise_level2=25) + + elif i == 5: + # add JPEG noise + if random.random() < jpeg_prob: + image = add_JPEG_noise(image) + + # elif i == 6: + # # add processed camera sensor noise + # if random.random() < isp_prob and isp_model is not None: + # with torch.no_grad(): + # img, hq = isp_model.forward(img.copy(), hq) + + # add final JPEG compression noise + image = add_JPEG_noise(image) + image = util.single2uint(image) + example = {"image":image} + return example + + +# TODO incase there is a pickle error one needs to replace a += x with a = a + x in add_speckle_noise etc... +def degradation_bsrgan_plus(img, sf=4, shuffle_prob=0.5, use_sharp=True, lq_patchsize=64, isp_model=None): + """ + This is an extended degradation model by combining + the degradation models of BSRGAN and Real-ESRGAN + ---------- + img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf) + sf: scale factor + use_shuffle: the degradation shuffle + use_sharp: sharpening the img + Returns + ------- + img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1] + hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1] + """ + + h1, w1 = img.shape[:2] + img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop + h, w = img.shape[:2] + + if h < lq_patchsize * sf or w < lq_patchsize * sf: + raise ValueError(f'img size ({h1}X{w1}) is too small!') + + if use_sharp: + img = add_sharpening(img) + hq = img.copy() + + if random.random() < shuffle_prob: + shuffle_order = random.sample(range(13), 13) + else: + shuffle_order = list(range(13)) + # local shuffle for noise, JPEG is always the last one + shuffle_order[2:6] = random.sample(shuffle_order[2:6], len(range(2, 6))) + shuffle_order[9:13] = random.sample(shuffle_order[9:13], len(range(9, 13))) + + poisson_prob, speckle_prob, isp_prob = 0.1, 0.1, 0.1 + + for i in shuffle_order: + if i == 0: + img = add_blur(img, sf=sf) + elif i == 1: + img = add_resize(img, sf=sf) + elif i == 2: + img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25) + elif i == 3: + if random.random() < poisson_prob: + img = add_Poisson_noise(img) + elif i == 4: + if random.random() < speckle_prob: + img = add_speckle_noise(img) + elif i == 5: + if random.random() < isp_prob and isp_model is not None: + with torch.no_grad(): + img, hq = isp_model.forward(img.copy(), hq) + elif i == 6: + img = add_JPEG_noise(img) + elif i == 7: + img = add_blur(img, sf=sf) + elif i == 8: + img = add_resize(img, sf=sf) + elif i == 9: + img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25) + elif i == 10: + if random.random() < poisson_prob: + img = add_Poisson_noise(img) + elif i == 11: + if random.random() < speckle_prob: + img = add_speckle_noise(img) + elif i == 12: + if random.random() < isp_prob and isp_model is not None: + with torch.no_grad(): + img, hq = isp_model.forward(img.copy(), hq) + else: + print('check the shuffle!') + + # resize to desired size + img = cv2.resize(img, (int(1 / sf * hq.shape[1]), int(1 / sf * hq.shape[0])), + interpolation=random.choice([1, 2, 3])) + + # add final JPEG compression noise + img = add_JPEG_noise(img) + + # random crop + img, hq = random_crop(img, hq, sf, lq_patchsize) + + return img, hq + + +if __name__ == '__main__': + print("hey") + img = util.imread_uint('utils/test.png', 3) + print(img) + img = util.uint2single(img) + print(img) + img = img[:448, :448] + h = img.shape[0] // 4 + print("resizing to", h) + sf = 4 + deg_fn = partial(degradation_bsrgan_variant, sf=sf) + for i in range(20): + print(i) + img_lq = deg_fn(img) + print(img_lq) + img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img)["image"] + print(img_lq.shape) + print("bicubic", img_lq_bicubic.shape) + print(img_hq.shape) + lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), + interpolation=0) + lq_bicubic_nearest = cv2.resize(util.single2uint(img_lq_bicubic), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), + interpolation=0) + img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1) + util.imsave(img_concat, str(i) + '.png') + + diff --git a/deforum-stable-diffusion/src/ldm/modules/image_degradation/bsrgan_light.py b/deforum-stable-diffusion/src/ldm/modules/image_degradation/bsrgan_light.py new file mode 100644 index 0000000000000000000000000000000000000000..9e1f823996bf559e9b015ea9aa2b3cd38dd13af1 --- /dev/null +++ b/deforum-stable-diffusion/src/ldm/modules/image_degradation/bsrgan_light.py @@ -0,0 +1,650 @@ +# -*- coding: utf-8 -*- +import numpy as np +import cv2 +import torch + +from functools import partial +import random +from scipy import ndimage +import scipy +import scipy.stats as ss +from scipy.interpolate import interp2d +from scipy.linalg import orth +import albumentations + +import ldm.modules.image_degradation.utils_image as util + +""" +# -------------------------------------------- +# Super-Resolution +# -------------------------------------------- +# +# Kai Zhang (cskaizhang@gmail.com) +# https://github.com/cszn +# From 2019/03--2021/08 +# -------------------------------------------- +""" + + +def modcrop_np(img, sf): + ''' + Args: + img: numpy image, WxH or WxHxC + sf: scale factor + Return: + cropped image + ''' + w, h = img.shape[:2] + im = np.copy(img) + return im[:w - w % sf, :h - h % sf, ...] + + +""" +# -------------------------------------------- +# anisotropic Gaussian kernels +# -------------------------------------------- +""" + + +def analytic_kernel(k): + """Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)""" + k_size = k.shape[0] + # Calculate the big kernels size + big_k = np.zeros((3 * k_size - 2, 3 * k_size - 2)) + # Loop over the small kernel to fill the big one + for r in range(k_size): + for c in range(k_size): + big_k[2 * r:2 * r + k_size, 2 * c:2 * c + k_size] += k[r, c] * k + # Crop the edges of the big kernel to ignore very small values and increase run time of SR + crop = k_size // 2 + cropped_big_k = big_k[crop:-crop, crop:-crop] + # Normalize to 1 + return cropped_big_k / cropped_big_k.sum() + + +def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6): + """ generate an anisotropic Gaussian kernel + Args: + ksize : e.g., 15, kernel size + theta : [0, pi], rotation angle range + l1 : [0.1,50], scaling of eigenvalues + l2 : [0.1,l1], scaling of eigenvalues + If l1 = l2, will get an isotropic Gaussian kernel. + Returns: + k : kernel + """ + + v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1., 0.])) + V = np.array([[v[0], v[1]], [v[1], -v[0]]]) + D = np.array([[l1, 0], [0, l2]]) + Sigma = np.dot(np.dot(V, D), np.linalg.inv(V)) + k = gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize) + + return k + + +def gm_blur_kernel(mean, cov, size=15): + center = size / 2.0 + 0.5 + k = np.zeros([size, size]) + for y in range(size): + for x in range(size): + cy = y - center + 1 + cx = x - center + 1 + k[y, x] = ss.multivariate_normal.pdf([cx, cy], mean=mean, cov=cov) + + k = k / np.sum(k) + return k + + +def shift_pixel(x, sf, upper_left=True): + """shift pixel for super-resolution with different scale factors + Args: + x: WxHxC or WxH + sf: scale factor + upper_left: shift direction + """ + h, w = x.shape[:2] + shift = (sf - 1) * 0.5 + xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0) + if upper_left: + x1 = xv + shift + y1 = yv + shift + else: + x1 = xv - shift + y1 = yv - shift + + x1 = np.clip(x1, 0, w - 1) + y1 = np.clip(y1, 0, h - 1) + + if x.ndim == 2: + x = interp2d(xv, yv, x)(x1, y1) + if x.ndim == 3: + for i in range(x.shape[-1]): + x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1) + + return x + + +def blur(x, k): + ''' + x: image, NxcxHxW + k: kernel, Nx1xhxw + ''' + n, c = x.shape[:2] + p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2 + x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode='replicate') + k = k.repeat(1, c, 1, 1) + k = k.view(-1, 1, k.shape[2], k.shape[3]) + x = x.view(1, -1, x.shape[2], x.shape[3]) + x = torch.nn.functional.conv2d(x, k, bias=None, stride=1, padding=0, groups=n * c) + x = x.view(n, c, x.shape[2], x.shape[3]) + + return x + + +def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0): + """" + # modified version of https://github.com/assafshocher/BlindSR_dataset_generator + # Kai Zhang + # min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var + # max_var = 2.5 * sf + """ + # Set random eigen-vals (lambdas) and angle (theta) for COV matrix + lambda_1 = min_var + np.random.rand() * (max_var - min_var) + lambda_2 = min_var + np.random.rand() * (max_var - min_var) + theta = np.random.rand() * np.pi # random theta + noise = -noise_level + np.random.rand(*k_size) * noise_level * 2 + + # Set COV matrix using Lambdas and Theta + LAMBDA = np.diag([lambda_1, lambda_2]) + Q = np.array([[np.cos(theta), -np.sin(theta)], + [np.sin(theta), np.cos(theta)]]) + SIGMA = Q @ LAMBDA @ Q.T + INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :] + + # Set expectation position (shifting kernel for aligned image) + MU = k_size // 2 - 0.5 * (scale_factor - 1) # - 0.5 * (scale_factor - k_size % 2) + MU = MU[None, None, :, None] + + # Create meshgrid for Gaussian + [X, Y] = np.meshgrid(range(k_size[0]), range(k_size[1])) + Z = np.stack([X, Y], 2)[:, :, :, None] + + # Calcualte Gaussian for every pixel of the kernel + ZZ = Z - MU + ZZ_t = ZZ.transpose(0, 1, 3, 2) + raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise) + + # shift the kernel so it will be centered + # raw_kernel_centered = kernel_shift(raw_kernel, scale_factor) + + # Normalize the kernel and return + # kernel = raw_kernel_centered / np.sum(raw_kernel_centered) + kernel = raw_kernel / np.sum(raw_kernel) + return kernel + + +def fspecial_gaussian(hsize, sigma): + hsize = [hsize, hsize] + siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0] + std = sigma + [x, y] = np.meshgrid(np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1)) + arg = -(x * x + y * y) / (2 * std * std) + h = np.exp(arg) + h[h < scipy.finfo(float).eps * h.max()] = 0 + sumh = h.sum() + if sumh != 0: + h = h / sumh + return h + + +def fspecial_laplacian(alpha): + alpha = max([0, min([alpha, 1])]) + h1 = alpha / (alpha + 1) + h2 = (1 - alpha) / (alpha + 1) + h = [[h1, h2, h1], [h2, -4 / (alpha + 1), h2], [h1, h2, h1]] + h = np.array(h) + return h + + +def fspecial(filter_type, *args, **kwargs): + ''' + python code from: + https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py + ''' + if filter_type == 'gaussian': + return fspecial_gaussian(*args, **kwargs) + if filter_type == 'laplacian': + return fspecial_laplacian(*args, **kwargs) + + +""" +# -------------------------------------------- +# degradation models +# -------------------------------------------- +""" + + +def bicubic_degradation(x, sf=3): + ''' + Args: + x: HxWxC image, [0, 1] + sf: down-scale factor + Return: + bicubicly downsampled LR image + ''' + x = util.imresize_np(x, scale=1 / sf) + return x + + +def srmd_degradation(x, k, sf=3): + ''' blur + bicubic downsampling + Args: + x: HxWxC image, [0, 1] + k: hxw, double + sf: down-scale factor + Return: + downsampled LR image + Reference: + @inproceedings{zhang2018learning, + title={Learning a single convolutional super-resolution network for multiple degradations}, + author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei}, + booktitle={IEEE Conference on Computer Vision and Pattern Recognition}, + pages={3262--3271}, + year={2018} + } + ''' + x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') # 'nearest' | 'mirror' + x = bicubic_degradation(x, sf=sf) + return x + + +def dpsr_degradation(x, k, sf=3): + ''' bicubic downsampling + blur + Args: + x: HxWxC image, [0, 1] + k: hxw, double + sf: down-scale factor + Return: + downsampled LR image + Reference: + @inproceedings{zhang2019deep, + title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels}, + author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei}, + booktitle={IEEE Conference on Computer Vision and Pattern Recognition}, + pages={1671--1681}, + year={2019} + } + ''' + x = bicubic_degradation(x, sf=sf) + x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') + return x + + +def classical_degradation(x, k, sf=3): + ''' blur + downsampling + Args: + x: HxWxC image, [0, 1]/[0, 255] + k: hxw, double + sf: down-scale factor + Return: + downsampled LR image + ''' + x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') + # x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2)) + st = 0 + return x[st::sf, st::sf, ...] + + +def add_sharpening(img, weight=0.5, radius=50, threshold=10): + """USM sharpening. borrowed from real-ESRGAN + Input image: I; Blurry image: B. + 1. K = I + weight * (I - B) + 2. Mask = 1 if abs(I - B) > threshold, else: 0 + 3. Blur mask: + 4. Out = Mask * K + (1 - Mask) * I + Args: + img (Numpy array): Input image, HWC, BGR; float32, [0, 1]. + weight (float): Sharp weight. Default: 1. + radius (float): Kernel size of Gaussian blur. Default: 50. + threshold (int): + """ + if radius % 2 == 0: + radius += 1 + blur = cv2.GaussianBlur(img, (radius, radius), 0) + residual = img - blur + mask = np.abs(residual) * 255 > threshold + mask = mask.astype('float32') + soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0) + + K = img + weight * residual + K = np.clip(K, 0, 1) + return soft_mask * K + (1 - soft_mask) * img + + +def add_blur(img, sf=4): + wd2 = 4.0 + sf + wd = 2.0 + 0.2 * sf + + wd2 = wd2/4 + wd = wd/4 + + if random.random() < 0.5: + l1 = wd2 * random.random() + l2 = wd2 * random.random() + k = anisotropic_Gaussian(ksize=random.randint(2, 11) + 3, theta=random.random() * np.pi, l1=l1, l2=l2) + else: + k = fspecial('gaussian', random.randint(2, 4) + 3, wd * random.random()) + img = ndimage.filters.convolve(img, np.expand_dims(k, axis=2), mode='mirror') + + return img + + +def add_resize(img, sf=4): + rnum = np.random.rand() + if rnum > 0.8: # up + sf1 = random.uniform(1, 2) + elif rnum < 0.7: # down + sf1 = random.uniform(0.5 / sf, 1) + else: + sf1 = 1.0 + img = cv2.resize(img, (int(sf1 * img.shape[1]), int(sf1 * img.shape[0])), interpolation=random.choice([1, 2, 3])) + img = np.clip(img, 0.0, 1.0) + + return img + + +# def add_Gaussian_noise(img, noise_level1=2, noise_level2=25): +# noise_level = random.randint(noise_level1, noise_level2) +# rnum = np.random.rand() +# if rnum > 0.6: # add color Gaussian noise +# img += np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32) +# elif rnum < 0.4: # add grayscale Gaussian noise +# img += np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32) +# else: # add noise +# L = noise_level2 / 255. +# D = np.diag(np.random.rand(3)) +# U = orth(np.random.rand(3, 3)) +# conv = np.dot(np.dot(np.transpose(U), D), U) +# img += np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32) +# img = np.clip(img, 0.0, 1.0) +# return img + +def add_Gaussian_noise(img, noise_level1=2, noise_level2=25): + noise_level = random.randint(noise_level1, noise_level2) + rnum = np.random.rand() + if rnum > 0.6: # add color Gaussian noise + img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32) + elif rnum < 0.4: # add grayscale Gaussian noise + img = img + np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32) + else: # add noise + L = noise_level2 / 255. + D = np.diag(np.random.rand(3)) + U = orth(np.random.rand(3, 3)) + conv = np.dot(np.dot(np.transpose(U), D), U) + img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32) + img = np.clip(img, 0.0, 1.0) + return img + + +def add_speckle_noise(img, noise_level1=2, noise_level2=25): + noise_level = random.randint(noise_level1, noise_level2) + img = np.clip(img, 0.0, 1.0) + rnum = random.random() + if rnum > 0.6: + img += img * np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32) + elif rnum < 0.4: + img += img * np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32) + else: + L = noise_level2 / 255. + D = np.diag(np.random.rand(3)) + U = orth(np.random.rand(3, 3)) + conv = np.dot(np.dot(np.transpose(U), D), U) + img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32) + img = np.clip(img, 0.0, 1.0) + return img + + +def add_Poisson_noise(img): + img = np.clip((img * 255.0).round(), 0, 255) / 255. + vals = 10 ** (2 * random.random() + 2.0) # [2, 4] + if random.random() < 0.5: + img = np.random.poisson(img * vals).astype(np.float32) / vals + else: + img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114]) + img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255. + noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray + img += noise_gray[:, :, np.newaxis] + img = np.clip(img, 0.0, 1.0) + return img + + +def add_JPEG_noise(img): + quality_factor = random.randint(80, 95) + img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR) + result, encimg = cv2.imencode('.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor]) + img = cv2.imdecode(encimg, 1) + img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB) + return img + + +def random_crop(lq, hq, sf=4, lq_patchsize=64): + h, w = lq.shape[:2] + rnd_h = random.randint(0, h - lq_patchsize) + rnd_w = random.randint(0, w - lq_patchsize) + lq = lq[rnd_h:rnd_h + lq_patchsize, rnd_w:rnd_w + lq_patchsize, :] + + rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf) + hq = hq[rnd_h_H:rnd_h_H + lq_patchsize * sf, rnd_w_H:rnd_w_H + lq_patchsize * sf, :] + return lq, hq + + +def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None): + """ + This is the degradation model of BSRGAN from the paper + "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution" + ---------- + img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf) + sf: scale factor + isp_model: camera ISP model + Returns + ------- + img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1] + hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1] + """ + isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25 + sf_ori = sf + + h1, w1 = img.shape[:2] + img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop + h, w = img.shape[:2] + + if h < lq_patchsize * sf or w < lq_patchsize * sf: + raise ValueError(f'img size ({h1}X{w1}) is too small!') + + hq = img.copy() + + if sf == 4 and random.random() < scale2_prob: # downsample1 + if np.random.rand() < 0.5: + img = cv2.resize(img, (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])), + interpolation=random.choice([1, 2, 3])) + else: + img = util.imresize_np(img, 1 / 2, True) + img = np.clip(img, 0.0, 1.0) + sf = 2 + + shuffle_order = random.sample(range(7), 7) + idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3) + if idx1 > idx2: # keep downsample3 last + shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1] + + for i in shuffle_order: + + if i == 0: + img = add_blur(img, sf=sf) + + elif i == 1: + img = add_blur(img, sf=sf) + + elif i == 2: + a, b = img.shape[1], img.shape[0] + # downsample2 + if random.random() < 0.75: + sf1 = random.uniform(1, 2 * sf) + img = cv2.resize(img, (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])), + interpolation=random.choice([1, 2, 3])) + else: + k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf)) + k_shifted = shift_pixel(k, sf) + k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel + img = ndimage.filters.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror') + img = img[0::sf, 0::sf, ...] # nearest downsampling + img = np.clip(img, 0.0, 1.0) + + elif i == 3: + # downsample3 + img = cv2.resize(img, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3])) + img = np.clip(img, 0.0, 1.0) + + elif i == 4: + # add Gaussian noise + img = add_Gaussian_noise(img, noise_level1=2, noise_level2=8) + + elif i == 5: + # add JPEG noise + if random.random() < jpeg_prob: + img = add_JPEG_noise(img) + + elif i == 6: + # add processed camera sensor noise + if random.random() < isp_prob and isp_model is not None: + with torch.no_grad(): + img, hq = isp_model.forward(img.copy(), hq) + + # add final JPEG compression noise + img = add_JPEG_noise(img) + + # random crop + img, hq = random_crop(img, hq, sf_ori, lq_patchsize) + + return img, hq + + +# todo no isp_model? +def degradation_bsrgan_variant(image, sf=4, isp_model=None): + """ + This is the degradation model of BSRGAN from the paper + "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution" + ---------- + sf: scale factor + isp_model: camera ISP model + Returns + ------- + img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1] + hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1] + """ + image = util.uint2single(image) + isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25 + sf_ori = sf + + h1, w1 = image.shape[:2] + image = image.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop + h, w = image.shape[:2] + + hq = image.copy() + + if sf == 4 and random.random() < scale2_prob: # downsample1 + if np.random.rand() < 0.5: + image = cv2.resize(image, (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])), + interpolation=random.choice([1, 2, 3])) + else: + image = util.imresize_np(image, 1 / 2, True) + image = np.clip(image, 0.0, 1.0) + sf = 2 + + shuffle_order = random.sample(range(7), 7) + idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3) + if idx1 > idx2: # keep downsample3 last + shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1] + + for i in shuffle_order: + + if i == 0: + image = add_blur(image, sf=sf) + + # elif i == 1: + # image = add_blur(image, sf=sf) + + if i == 0: + pass + + elif i == 2: + a, b = image.shape[1], image.shape[0] + # downsample2 + if random.random() < 0.8: + sf1 = random.uniform(1, 2 * sf) + image = cv2.resize(image, (int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0])), + interpolation=random.choice([1, 2, 3])) + else: + k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf)) + k_shifted = shift_pixel(k, sf) + k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel + image = ndimage.filters.convolve(image, np.expand_dims(k_shifted, axis=2), mode='mirror') + image = image[0::sf, 0::sf, ...] # nearest downsampling + + image = np.clip(image, 0.0, 1.0) + + elif i == 3: + # downsample3 + image = cv2.resize(image, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3])) + image = np.clip(image, 0.0, 1.0) + + elif i == 4: + # add Gaussian noise + image = add_Gaussian_noise(image, noise_level1=1, noise_level2=2) + + elif i == 5: + # add JPEG noise + if random.random() < jpeg_prob: + image = add_JPEG_noise(image) + # + # elif i == 6: + # # add processed camera sensor noise + # if random.random() < isp_prob and isp_model is not None: + # with torch.no_grad(): + # img, hq = isp_model.forward(img.copy(), hq) + + # add final JPEG compression noise + image = add_JPEG_noise(image) + image = util.single2uint(image) + example = {"image": image} + return example + + + + +if __name__ == '__main__': + print("hey") + img = util.imread_uint('utils/test.png', 3) + img = img[:448, :448] + h = img.shape[0] // 4 + print("resizing to", h) + sf = 4 + deg_fn = partial(degradation_bsrgan_variant, sf=sf) + for i in range(20): + print(i) + img_hq = img + img_lq = deg_fn(img)["image"] + img_hq, img_lq = util.uint2single(img_hq), util.uint2single(img_lq) + print(img_lq) + img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img_hq)["image"] + print(img_lq.shape) + print("bicubic", img_lq_bicubic.shape) + print(img_hq.shape) + lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), + interpolation=0) + lq_bicubic_nearest = cv2.resize(util.single2uint(img_lq_bicubic), + (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), + interpolation=0) + img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1) + util.imsave(img_concat, str(i) + '.png') diff --git a/deforum-stable-diffusion/src/ldm/modules/image_degradation/utils/test.png b/deforum-stable-diffusion/src/ldm/modules/image_degradation/utils/test.png new file mode 100644 index 0000000000000000000000000000000000000000..4249b43de0f22707758d13c240268a401642f6e6 Binary files /dev/null and b/deforum-stable-diffusion/src/ldm/modules/image_degradation/utils/test.png differ diff --git a/deforum-stable-diffusion/src/ldm/modules/image_degradation/utils_image.py b/deforum-stable-diffusion/src/ldm/modules/image_degradation/utils_image.py new file mode 100644 index 0000000000000000000000000000000000000000..0175f155ad900ae33c3c46ed87f49b352e3faf98 --- /dev/null +++ b/deforum-stable-diffusion/src/ldm/modules/image_degradation/utils_image.py @@ -0,0 +1,916 @@ +import os +import math +import random +import numpy as np +import torch +import cv2 +from torchvision.utils import make_grid +from datetime import datetime +#import matplotlib.pyplot as plt # TODO: check with Dominik, also bsrgan.py vs bsrgan_light.py + + +os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE" + + +''' +# -------------------------------------------- +# Kai Zhang (github: https://github.com/cszn) +# 03/Mar/2019 +# -------------------------------------------- +# https://github.com/twhui/SRGAN-pyTorch +# https://github.com/xinntao/BasicSR +# -------------------------------------------- +''' + + +IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tif'] + + +def is_image_file(filename): + return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) + + +def get_timestamp(): + return datetime.now().strftime('%y%m%d-%H%M%S') + + +def imshow(x, title=None, cbar=False, figsize=None): + plt.figure(figsize=figsize) + plt.imshow(np.squeeze(x), interpolation='nearest', cmap='gray') + if title: + plt.title(title) + if cbar: + plt.colorbar() + plt.show() + + +def surf(Z, cmap='rainbow', figsize=None): + plt.figure(figsize=figsize) + ax3 = plt.axes(projection='3d') + + w, h = Z.shape[:2] + xx = np.arange(0,w,1) + yy = np.arange(0,h,1) + X, Y = np.meshgrid(xx, yy) + ax3.plot_surface(X,Y,Z,cmap=cmap) + #ax3.contour(X,Y,Z, zdim='z',offset=-2,cmap=cmap) + plt.show() + + +''' +# -------------------------------------------- +# get image pathes +# -------------------------------------------- +''' + + +def get_image_paths(dataroot): + paths = None # return None if dataroot is None + if dataroot is not None: + paths = sorted(_get_paths_from_images(dataroot)) + return paths + + +def _get_paths_from_images(path): + assert os.path.isdir(path), '{:s} is not a valid directory'.format(path) + images = [] + for dirpath, _, fnames in sorted(os.walk(path)): + for fname in sorted(fnames): + if is_image_file(fname): + img_path = os.path.join(dirpath, fname) + images.append(img_path) + assert images, '{:s} has no valid image file'.format(path) + return images + + +''' +# -------------------------------------------- +# split large images into small images +# -------------------------------------------- +''' + + +def patches_from_image(img, p_size=512, p_overlap=64, p_max=800): + w, h = img.shape[:2] + patches = [] + if w > p_max and h > p_max: + w1 = list(np.arange(0, w-p_size, p_size-p_overlap, dtype=np.int)) + h1 = list(np.arange(0, h-p_size, p_size-p_overlap, dtype=np.int)) + w1.append(w-p_size) + h1.append(h-p_size) +# print(w1) +# print(h1) + for i in w1: + for j in h1: + patches.append(img[i:i+p_size, j:j+p_size,:]) + else: + patches.append(img) + + return patches + + +def imssave(imgs, img_path): + """ + imgs: list, N images of size WxHxC + """ + img_name, ext = os.path.splitext(os.path.basename(img_path)) + + for i, img in enumerate(imgs): + if img.ndim == 3: + img = img[:, :, [2, 1, 0]] + new_path = os.path.join(os.path.dirname(img_path), img_name+str('_s{:04d}'.format(i))+'.png') + cv2.imwrite(new_path, img) + + +def split_imageset(original_dataroot, taget_dataroot, n_channels=3, p_size=800, p_overlap=96, p_max=1000): + """ + split the large images from original_dataroot into small overlapped images with size (p_size)x(p_size), + and save them into taget_dataroot; only the images with larger size than (p_max)x(p_max) + will be splitted. + Args: + original_dataroot: + taget_dataroot: + p_size: size of small images + p_overlap: patch size in training is a good choice + p_max: images with smaller size than (p_max)x(p_max) keep unchanged. + """ + paths = get_image_paths(original_dataroot) + for img_path in paths: + # img_name, ext = os.path.splitext(os.path.basename(img_path)) + img = imread_uint(img_path, n_channels=n_channels) + patches = patches_from_image(img, p_size, p_overlap, p_max) + imssave(patches, os.path.join(taget_dataroot,os.path.basename(img_path))) + #if original_dataroot == taget_dataroot: + #del img_path + +''' +# -------------------------------------------- +# makedir +# -------------------------------------------- +''' + + +def mkdir(path): + if not os.path.exists(path): + os.makedirs(path) + + +def mkdirs(paths): + if isinstance(paths, str): + mkdir(paths) + else: + for path in paths: + mkdir(path) + + +def mkdir_and_rename(path): + if os.path.exists(path): + new_name = path + '_archived_' + get_timestamp() + print('Path already exists. Rename it to [{:s}]'.format(new_name)) + os.rename(path, new_name) + os.makedirs(path) + + +''' +# -------------------------------------------- +# read image from path +# opencv is fast, but read BGR numpy image +# -------------------------------------------- +''' + + +# -------------------------------------------- +# get uint8 image of size HxWxn_channles (RGB) +# -------------------------------------------- +def imread_uint(path, n_channels=3): + # input: path + # output: HxWx3(RGB or GGG), or HxWx1 (G) + if n_channels == 1: + img = cv2.imread(path, 0) # cv2.IMREAD_GRAYSCALE + img = np.expand_dims(img, axis=2) # HxWx1 + elif n_channels == 3: + img = cv2.imread(path, cv2.IMREAD_UNCHANGED) # BGR or G + if img.ndim == 2: + img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB) # GGG + else: + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # RGB + return img + + +# -------------------------------------------- +# matlab's imwrite +# -------------------------------------------- +def imsave(img, img_path): + img = np.squeeze(img) + if img.ndim == 3: + img = img[:, :, [2, 1, 0]] + cv2.imwrite(img_path, img) + +def imwrite(img, img_path): + img = np.squeeze(img) + if img.ndim == 3: + img = img[:, :, [2, 1, 0]] + cv2.imwrite(img_path, img) + + + +# -------------------------------------------- +# get single image of size HxWxn_channles (BGR) +# -------------------------------------------- +def read_img(path): + # read image by cv2 + # return: Numpy float32, HWC, BGR, [0,1] + img = cv2.imread(path, cv2.IMREAD_UNCHANGED) # cv2.IMREAD_GRAYSCALE + img = img.astype(np.float32) / 255. + if img.ndim == 2: + img = np.expand_dims(img, axis=2) + # some images have 4 channels + if img.shape[2] > 3: + img = img[:, :, :3] + return img + + +''' +# -------------------------------------------- +# image format conversion +# -------------------------------------------- +# numpy(single) <---> numpy(unit) +# numpy(single) <---> tensor +# numpy(unit) <---> tensor +# -------------------------------------------- +''' + + +# -------------------------------------------- +# numpy(single) [0, 1] <---> numpy(unit) +# -------------------------------------------- + + +def uint2single(img): + + return np.float32(img/255.) + + +def single2uint(img): + + return np.uint8((img.clip(0, 1)*255.).round()) + + +def uint162single(img): + + return np.float32(img/65535.) + + +def single2uint16(img): + + return np.uint16((img.clip(0, 1)*65535.).round()) + + +# -------------------------------------------- +# numpy(unit) (HxWxC or HxW) <---> tensor +# -------------------------------------------- + + +# convert uint to 4-dimensional torch tensor +def uint2tensor4(img): + if img.ndim == 2: + img = np.expand_dims(img, axis=2) + return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.).unsqueeze(0) + + +# convert uint to 3-dimensional torch tensor +def uint2tensor3(img): + if img.ndim == 2: + img = np.expand_dims(img, axis=2) + return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.) + + +# convert 2/3/4-dimensional torch tensor to uint +def tensor2uint(img): + img = img.data.squeeze().float().clamp_(0, 1).cpu().numpy() + if img.ndim == 3: + img = np.transpose(img, (1, 2, 0)) + return np.uint8((img*255.0).round()) + + +# -------------------------------------------- +# numpy(single) (HxWxC) <---> tensor +# -------------------------------------------- + + +# convert single (HxWxC) to 3-dimensional torch tensor +def single2tensor3(img): + return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float() + + +# convert single (HxWxC) to 4-dimensional torch tensor +def single2tensor4(img): + return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().unsqueeze(0) + + +# convert torch tensor to single +def tensor2single(img): + img = img.data.squeeze().float().cpu().numpy() + if img.ndim == 3: + img = np.transpose(img, (1, 2, 0)) + + return img + +# convert torch tensor to single +def tensor2single3(img): + img = img.data.squeeze().float().cpu().numpy() + if img.ndim == 3: + img = np.transpose(img, (1, 2, 0)) + elif img.ndim == 2: + img = np.expand_dims(img, axis=2) + return img + + +def single2tensor5(img): + return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float().unsqueeze(0) + + +def single32tensor5(img): + return torch.from_numpy(np.ascontiguousarray(img)).float().unsqueeze(0).unsqueeze(0) + + +def single42tensor4(img): + return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float() + + +# from skimage.io import imread, imsave +def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)): + ''' + Converts a torch Tensor into an image Numpy array of BGR channel order + Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order + Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default) + ''' + tensor = tensor.squeeze().float().cpu().clamp_(*min_max) # squeeze first, then clamp + tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0]) # to range [0,1] + n_dim = tensor.dim() + if n_dim == 4: + n_img = len(tensor) + img_np = make_grid(tensor, nrow=int(math.sqrt(n_img)), normalize=False).numpy() + img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR + elif n_dim == 3: + img_np = tensor.numpy() + img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR + elif n_dim == 2: + img_np = tensor.numpy() + else: + raise TypeError( + 'Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(n_dim)) + if out_type == np.uint8: + img_np = (img_np * 255.0).round() + # Important. Unlike matlab, numpy.unit8() WILL NOT round by default. + return img_np.astype(out_type) + + +''' +# -------------------------------------------- +# Augmentation, flipe and/or rotate +# -------------------------------------------- +# The following two are enough. +# (1) augmet_img: numpy image of WxHxC or WxH +# (2) augment_img_tensor4: tensor image 1xCxWxH +# -------------------------------------------- +''' + + +def augment_img(img, mode=0): + '''Kai Zhang (github: https://github.com/cszn) + ''' + if mode == 0: + return img + elif mode == 1: + return np.flipud(np.rot90(img)) + elif mode == 2: + return np.flipud(img) + elif mode == 3: + return np.rot90(img, k=3) + elif mode == 4: + return np.flipud(np.rot90(img, k=2)) + elif mode == 5: + return np.rot90(img) + elif mode == 6: + return np.rot90(img, k=2) + elif mode == 7: + return np.flipud(np.rot90(img, k=3)) + + +def augment_img_tensor4(img, mode=0): + '''Kai Zhang (github: https://github.com/cszn) + ''' + if mode == 0: + return img + elif mode == 1: + return img.rot90(1, [2, 3]).flip([2]) + elif mode == 2: + return img.flip([2]) + elif mode == 3: + return img.rot90(3, [2, 3]) + elif mode == 4: + return img.rot90(2, [2, 3]).flip([2]) + elif mode == 5: + return img.rot90(1, [2, 3]) + elif mode == 6: + return img.rot90(2, [2, 3]) + elif mode == 7: + return img.rot90(3, [2, 3]).flip([2]) + + +def augment_img_tensor(img, mode=0): + '''Kai Zhang (github: https://github.com/cszn) + ''' + img_size = img.size() + img_np = img.data.cpu().numpy() + if len(img_size) == 3: + img_np = np.transpose(img_np, (1, 2, 0)) + elif len(img_size) == 4: + img_np = np.transpose(img_np, (2, 3, 1, 0)) + img_np = augment_img(img_np, mode=mode) + img_tensor = torch.from_numpy(np.ascontiguousarray(img_np)) + if len(img_size) == 3: + img_tensor = img_tensor.permute(2, 0, 1) + elif len(img_size) == 4: + img_tensor = img_tensor.permute(3, 2, 0, 1) + + return img_tensor.type_as(img) + + +def augment_img_np3(img, mode=0): + if mode == 0: + return img + elif mode == 1: + return img.transpose(1, 0, 2) + elif mode == 2: + return img[::-1, :, :] + elif mode == 3: + img = img[::-1, :, :] + img = img.transpose(1, 0, 2) + return img + elif mode == 4: + return img[:, ::-1, :] + elif mode == 5: + img = img[:, ::-1, :] + img = img.transpose(1, 0, 2) + return img + elif mode == 6: + img = img[:, ::-1, :] + img = img[::-1, :, :] + return img + elif mode == 7: + img = img[:, ::-1, :] + img = img[::-1, :, :] + img = img.transpose(1, 0, 2) + return img + + +def augment_imgs(img_list, hflip=True, rot=True): + # horizontal flip OR rotate + hflip = hflip and random.random() < 0.5 + vflip = rot and random.random() < 0.5 + rot90 = rot and random.random() < 0.5 + + def _augment(img): + if hflip: + img = img[:, ::-1, :] + if vflip: + img = img[::-1, :, :] + if rot90: + img = img.transpose(1, 0, 2) + return img + + return [_augment(img) for img in img_list] + + +''' +# -------------------------------------------- +# modcrop and shave +# -------------------------------------------- +''' + + +def modcrop(img_in, scale): + # img_in: Numpy, HWC or HW + img = np.copy(img_in) + if img.ndim == 2: + H, W = img.shape + H_r, W_r = H % scale, W % scale + img = img[:H - H_r, :W - W_r] + elif img.ndim == 3: + H, W, C = img.shape + H_r, W_r = H % scale, W % scale + img = img[:H - H_r, :W - W_r, :] + else: + raise ValueError('Wrong img ndim: [{:d}].'.format(img.ndim)) + return img + + +def shave(img_in, border=0): + # img_in: Numpy, HWC or HW + img = np.copy(img_in) + h, w = img.shape[:2] + img = img[border:h-border, border:w-border] + return img + + +''' +# -------------------------------------------- +# image processing process on numpy image +# channel_convert(in_c, tar_type, img_list): +# rgb2ycbcr(img, only_y=True): +# bgr2ycbcr(img, only_y=True): +# ycbcr2rgb(img): +# -------------------------------------------- +''' + + +def rgb2ycbcr(img, only_y=True): + '''same as matlab rgb2ycbcr + only_y: only return Y channel + Input: + uint8, [0, 255] + float, [0, 1] + ''' + in_img_type = img.dtype + img.astype(np.float32) + if in_img_type != np.uint8: + img *= 255. + # convert + if only_y: + rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0 + else: + rlt = np.matmul(img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], + [24.966, 112.0, -18.214]]) / 255.0 + [16, 128, 128] + if in_img_type == np.uint8: + rlt = rlt.round() + else: + rlt /= 255. + return rlt.astype(in_img_type) + + +def ycbcr2rgb(img): + '''same as matlab ycbcr2rgb + Input: + uint8, [0, 255] + float, [0, 1] + ''' + in_img_type = img.dtype + img.astype(np.float32) + if in_img_type != np.uint8: + img *= 255. + # convert + rlt = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071], + [0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836] + if in_img_type == np.uint8: + rlt = rlt.round() + else: + rlt /= 255. + return rlt.astype(in_img_type) + + +def bgr2ycbcr(img, only_y=True): + '''bgr version of rgb2ycbcr + only_y: only return Y channel + Input: + uint8, [0, 255] + float, [0, 1] + ''' + in_img_type = img.dtype + img.astype(np.float32) + if in_img_type != np.uint8: + img *= 255. + # convert + if only_y: + rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0 + else: + rlt = np.matmul(img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786], + [65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128] + if in_img_type == np.uint8: + rlt = rlt.round() + else: + rlt /= 255. + return rlt.astype(in_img_type) + + +def channel_convert(in_c, tar_type, img_list): + # conversion among BGR, gray and y + if in_c == 3 and tar_type == 'gray': # BGR to gray + gray_list = [cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) for img in img_list] + return [np.expand_dims(img, axis=2) for img in gray_list] + elif in_c == 3 and tar_type == 'y': # BGR to y + y_list = [bgr2ycbcr(img, only_y=True) for img in img_list] + return [np.expand_dims(img, axis=2) for img in y_list] + elif in_c == 1 and tar_type == 'RGB': # gray/y to BGR + return [cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) for img in img_list] + else: + return img_list + + +''' +# -------------------------------------------- +# metric, PSNR and SSIM +# -------------------------------------------- +''' + + +# -------------------------------------------- +# PSNR +# -------------------------------------------- +def calculate_psnr(img1, img2, border=0): + # img1 and img2 have range [0, 255] + #img1 = img1.squeeze() + #img2 = img2.squeeze() + if not img1.shape == img2.shape: + raise ValueError('Input images must have the same dimensions.') + h, w = img1.shape[:2] + img1 = img1[border:h-border, border:w-border] + img2 = img2[border:h-border, border:w-border] + + img1 = img1.astype(np.float64) + img2 = img2.astype(np.float64) + mse = np.mean((img1 - img2)**2) + if mse == 0: + return float('inf') + return 20 * math.log10(255.0 / math.sqrt(mse)) + + +# -------------------------------------------- +# SSIM +# -------------------------------------------- +def calculate_ssim(img1, img2, border=0): + '''calculate SSIM + the same outputs as MATLAB's + img1, img2: [0, 255] + ''' + #img1 = img1.squeeze() + #img2 = img2.squeeze() + if not img1.shape == img2.shape: + raise ValueError('Input images must have the same dimensions.') + h, w = img1.shape[:2] + img1 = img1[border:h-border, border:w-border] + img2 = img2[border:h-border, border:w-border] + + if img1.ndim == 2: + return ssim(img1, img2) + elif img1.ndim == 3: + if img1.shape[2] == 3: + ssims = [] + for i in range(3): + ssims.append(ssim(img1[:,:,i], img2[:,:,i])) + return np.array(ssims).mean() + elif img1.shape[2] == 1: + return ssim(np.squeeze(img1), np.squeeze(img2)) + else: + raise ValueError('Wrong input image dimensions.') + + +def ssim(img1, img2): + C1 = (0.01 * 255)**2 + C2 = (0.03 * 255)**2 + + img1 = img1.astype(np.float64) + img2 = img2.astype(np.float64) + kernel = cv2.getGaussianKernel(11, 1.5) + window = np.outer(kernel, kernel.transpose()) + + mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid + mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5] + mu1_sq = mu1**2 + mu2_sq = mu2**2 + mu1_mu2 = mu1 * mu2 + sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq + sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq + sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2 + + ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * + (sigma1_sq + sigma2_sq + C2)) + return ssim_map.mean() + + +''' +# -------------------------------------------- +# matlab's bicubic imresize (numpy and torch) [0, 1] +# -------------------------------------------- +''' + + +# matlab 'imresize' function, now only support 'bicubic' +def cubic(x): + absx = torch.abs(x) + absx2 = absx**2 + absx3 = absx**3 + return (1.5*absx3 - 2.5*absx2 + 1) * ((absx <= 1).type_as(absx)) + \ + (-0.5*absx3 + 2.5*absx2 - 4*absx + 2) * (((absx > 1)*(absx <= 2)).type_as(absx)) + + +def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing): + if (scale < 1) and (antialiasing): + # Use a modified kernel to simultaneously interpolate and antialias- larger kernel width + kernel_width = kernel_width / scale + + # Output-space coordinates + x = torch.linspace(1, out_length, out_length) + + # Input-space coordinates. Calculate the inverse mapping such that 0.5 + # in output space maps to 0.5 in input space, and 0.5+scale in output + # space maps to 1.5 in input space. + u = x / scale + 0.5 * (1 - 1 / scale) + + # What is the left-most pixel that can be involved in the computation? + left = torch.floor(u - kernel_width / 2) + + # What is the maximum number of pixels that can be involved in the + # computation? Note: it's OK to use an extra pixel here; if the + # corresponding weights are all zero, it will be eliminated at the end + # of this function. + P = math.ceil(kernel_width) + 2 + + # The indices of the input pixels involved in computing the k-th output + # pixel are in row k of the indices matrix. + indices = left.view(out_length, 1).expand(out_length, P) + torch.linspace(0, P - 1, P).view( + 1, P).expand(out_length, P) + + # The weights used to compute the k-th output pixel are in row k of the + # weights matrix. + distance_to_center = u.view(out_length, 1).expand(out_length, P) - indices + # apply cubic kernel + if (scale < 1) and (antialiasing): + weights = scale * cubic(distance_to_center * scale) + else: + weights = cubic(distance_to_center) + # Normalize the weights matrix so that each row sums to 1. + weights_sum = torch.sum(weights, 1).view(out_length, 1) + weights = weights / weights_sum.expand(out_length, P) + + # If a column in weights is all zero, get rid of it. only consider the first and last column. + weights_zero_tmp = torch.sum((weights == 0), 0) + if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6): + indices = indices.narrow(1, 1, P - 2) + weights = weights.narrow(1, 1, P - 2) + if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6): + indices = indices.narrow(1, 0, P - 2) + weights = weights.narrow(1, 0, P - 2) + weights = weights.contiguous() + indices = indices.contiguous() + sym_len_s = -indices.min() + 1 + sym_len_e = indices.max() - in_length + indices = indices + sym_len_s - 1 + return weights, indices, int(sym_len_s), int(sym_len_e) + + +# -------------------------------------------- +# imresize for tensor image [0, 1] +# -------------------------------------------- +def imresize(img, scale, antialiasing=True): + # Now the scale should be the same for H and W + # input: img: pytorch tensor, CHW or HW [0,1] + # output: CHW or HW [0,1] w/o round + need_squeeze = True if img.dim() == 2 else False + if need_squeeze: + img.unsqueeze_(0) + in_C, in_H, in_W = img.size() + out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale) + kernel_width = 4 + kernel = 'cubic' + + # Return the desired dimension order for performing the resize. The + # strategy is to perform the resize first along the dimension with the + # smallest scale factor. + # Now we do not support this. + + # get weights and indices + weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices( + in_H, out_H, scale, kernel, kernel_width, antialiasing) + weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices( + in_W, out_W, scale, kernel, kernel_width, antialiasing) + # process H dimension + # symmetric copying + img_aug = torch.FloatTensor(in_C, in_H + sym_len_Hs + sym_len_He, in_W) + img_aug.narrow(1, sym_len_Hs, in_H).copy_(img) + + sym_patch = img[:, :sym_len_Hs, :] + inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(1, inv_idx) + img_aug.narrow(1, 0, sym_len_Hs).copy_(sym_patch_inv) + + sym_patch = img[:, -sym_len_He:, :] + inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(1, inv_idx) + img_aug.narrow(1, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv) + + out_1 = torch.FloatTensor(in_C, out_H, in_W) + kernel_width = weights_H.size(1) + for i in range(out_H): + idx = int(indices_H[i][0]) + for j in range(out_C): + out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i]) + + # process W dimension + # symmetric copying + out_1_aug = torch.FloatTensor(in_C, out_H, in_W + sym_len_Ws + sym_len_We) + out_1_aug.narrow(2, sym_len_Ws, in_W).copy_(out_1) + + sym_patch = out_1[:, :, :sym_len_Ws] + inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(2, inv_idx) + out_1_aug.narrow(2, 0, sym_len_Ws).copy_(sym_patch_inv) + + sym_patch = out_1[:, :, -sym_len_We:] + inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(2, inv_idx) + out_1_aug.narrow(2, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv) + + out_2 = torch.FloatTensor(in_C, out_H, out_W) + kernel_width = weights_W.size(1) + for i in range(out_W): + idx = int(indices_W[i][0]) + for j in range(out_C): + out_2[j, :, i] = out_1_aug[j, :, idx:idx + kernel_width].mv(weights_W[i]) + if need_squeeze: + out_2.squeeze_() + return out_2 + + +# -------------------------------------------- +# imresize for numpy image [0, 1] +# -------------------------------------------- +def imresize_np(img, scale, antialiasing=True): + # Now the scale should be the same for H and W + # input: img: Numpy, HWC or HW [0,1] + # output: HWC or HW [0,1] w/o round + img = torch.from_numpy(img) + need_squeeze = True if img.dim() == 2 else False + if need_squeeze: + img.unsqueeze_(2) + + in_H, in_W, in_C = img.size() + out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale) + kernel_width = 4 + kernel = 'cubic' + + # Return the desired dimension order for performing the resize. The + # strategy is to perform the resize first along the dimension with the + # smallest scale factor. + # Now we do not support this. + + # get weights and indices + weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices( + in_H, out_H, scale, kernel, kernel_width, antialiasing) + weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices( + in_W, out_W, scale, kernel, kernel_width, antialiasing) + # process H dimension + # symmetric copying + img_aug = torch.FloatTensor(in_H + sym_len_Hs + sym_len_He, in_W, in_C) + img_aug.narrow(0, sym_len_Hs, in_H).copy_(img) + + sym_patch = img[:sym_len_Hs, :, :] + inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(0, inv_idx) + img_aug.narrow(0, 0, sym_len_Hs).copy_(sym_patch_inv) + + sym_patch = img[-sym_len_He:, :, :] + inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(0, inv_idx) + img_aug.narrow(0, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv) + + out_1 = torch.FloatTensor(out_H, in_W, in_C) + kernel_width = weights_H.size(1) + for i in range(out_H): + idx = int(indices_H[i][0]) + for j in range(out_C): + out_1[i, :, j] = img_aug[idx:idx + kernel_width, :, j].transpose(0, 1).mv(weights_H[i]) + + # process W dimension + # symmetric copying + out_1_aug = torch.FloatTensor(out_H, in_W + sym_len_Ws + sym_len_We, in_C) + out_1_aug.narrow(1, sym_len_Ws, in_W).copy_(out_1) + + sym_patch = out_1[:, :sym_len_Ws, :] + inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(1, inv_idx) + out_1_aug.narrow(1, 0, sym_len_Ws).copy_(sym_patch_inv) + + sym_patch = out_1[:, -sym_len_We:, :] + inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(1, inv_idx) + out_1_aug.narrow(1, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv) + + out_2 = torch.FloatTensor(out_H, out_W, in_C) + kernel_width = weights_W.size(1) + for i in range(out_W): + idx = int(indices_W[i][0]) + for j in range(out_C): + out_2[:, i, j] = out_1_aug[:, idx:idx + kernel_width, j].mv(weights_W[i]) + if need_squeeze: + out_2.squeeze_() + + return out_2.numpy() + + +if __name__ == '__main__': + print('---') +# img = imread_uint('test.bmp', 3) +# img = uint2single(img) +# img_bicubic = imresize_np(img, 1/4) \ No newline at end of file diff --git a/deforum-stable-diffusion/src/ldm/modules/losses/__init__.py b/deforum-stable-diffusion/src/ldm/modules/losses/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..876d7c5bd6e3245ee77feb4c482b7a8143604ad5 --- /dev/null +++ b/deforum-stable-diffusion/src/ldm/modules/losses/__init__.py @@ -0,0 +1 @@ +from ldm.modules.losses.contperceptual import LPIPSWithDiscriminator \ No newline at end of file diff --git a/deforum-stable-diffusion/src/ldm/modules/losses/contperceptual.py b/deforum-stable-diffusion/src/ldm/modules/losses/contperceptual.py new file mode 100644 index 0000000000000000000000000000000000000000..672c1e32a1389def02461c0781339681060c540e --- /dev/null +++ b/deforum-stable-diffusion/src/ldm/modules/losses/contperceptual.py @@ -0,0 +1,111 @@ +import torch +import torch.nn as nn + +from taming.modules.losses.vqperceptual import * # TODO: taming dependency yes/no? + + +class LPIPSWithDiscriminator(nn.Module): + def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0, + disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, + perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, + disc_loss="hinge"): + + super().__init__() + assert disc_loss in ["hinge", "vanilla"] + self.kl_weight = kl_weight + self.pixel_weight = pixelloss_weight + self.perceptual_loss = LPIPS().eval() + self.perceptual_weight = perceptual_weight + # output log variance + self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init) + + self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, + n_layers=disc_num_layers, + use_actnorm=use_actnorm + ).apply(weights_init) + self.discriminator_iter_start = disc_start + self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss + self.disc_factor = disc_factor + self.discriminator_weight = disc_weight + self.disc_conditional = disc_conditional + + def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): + if last_layer is not None: + nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] + g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] + else: + nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] + g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] + + d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) + d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() + d_weight = d_weight * self.discriminator_weight + return d_weight + + def forward(self, inputs, reconstructions, posteriors, optimizer_idx, + global_step, last_layer=None, cond=None, split="train", + weights=None): + rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) + if self.perceptual_weight > 0: + p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) + rec_loss = rec_loss + self.perceptual_weight * p_loss + + nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar + weighted_nll_loss = nll_loss + if weights is not None: + weighted_nll_loss = weights*nll_loss + weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0] + nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] + kl_loss = posteriors.kl() + kl_loss = torch.sum(kl_loss) / kl_loss.shape[0] + + # now the GAN part + if optimizer_idx == 0: + # generator update + if cond is None: + assert not self.disc_conditional + logits_fake = self.discriminator(reconstructions.contiguous()) + else: + assert self.disc_conditional + logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1)) + g_loss = -torch.mean(logits_fake) + + if self.disc_factor > 0.0: + try: + d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) + except RuntimeError: + assert not self.training + d_weight = torch.tensor(0.0) + else: + d_weight = torch.tensor(0.0) + + disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) + loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss + + log = {"{}/total_loss".format(split): loss.clone().detach().mean(), "{}/logvar".format(split): self.logvar.detach(), + "{}/kl_loss".format(split): kl_loss.detach().mean(), "{}/nll_loss".format(split): nll_loss.detach().mean(), + "{}/rec_loss".format(split): rec_loss.detach().mean(), + "{}/d_weight".format(split): d_weight.detach(), + "{}/disc_factor".format(split): torch.tensor(disc_factor), + "{}/g_loss".format(split): g_loss.detach().mean(), + } + return loss, log + + if optimizer_idx == 1: + # second pass for discriminator update + if cond is None: + logits_real = self.discriminator(inputs.contiguous().detach()) + logits_fake = self.discriminator(reconstructions.contiguous().detach()) + else: + logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1)) + logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1)) + + disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) + d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) + + log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), + "{}/logits_real".format(split): logits_real.detach().mean(), + "{}/logits_fake".format(split): logits_fake.detach().mean() + } + return d_loss, log + diff --git a/deforum-stable-diffusion/src/ldm/modules/losses/vqperceptual.py b/deforum-stable-diffusion/src/ldm/modules/losses/vqperceptual.py new file mode 100644 index 0000000000000000000000000000000000000000..f69981769e4bd5462600458c4fcf26620f7e4306 --- /dev/null +++ b/deforum-stable-diffusion/src/ldm/modules/losses/vqperceptual.py @@ -0,0 +1,167 @@ +import torch +from torch import nn +import torch.nn.functional as F +from einops import repeat + +from taming.modules.discriminator.model import NLayerDiscriminator, weights_init +from taming.modules.losses.lpips import LPIPS +from taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss + + +def hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights): + assert weights.shape[0] == logits_real.shape[0] == logits_fake.shape[0] + loss_real = torch.mean(F.relu(1. - logits_real), dim=[1,2,3]) + loss_fake = torch.mean(F.relu(1. + logits_fake), dim=[1,2,3]) + loss_real = (weights * loss_real).sum() / weights.sum() + loss_fake = (weights * loss_fake).sum() / weights.sum() + d_loss = 0.5 * (loss_real + loss_fake) + return d_loss + +def adopt_weight(weight, global_step, threshold=0, value=0.): + if global_step < threshold: + weight = value + return weight + + +def measure_perplexity(predicted_indices, n_embed): + # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py + # eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally + encodings = F.one_hot(predicted_indices, n_embed).float().reshape(-1, n_embed) + avg_probs = encodings.mean(0) + perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp() + cluster_use = torch.sum(avg_probs > 0) + return perplexity, cluster_use + +def l1(x, y): + return torch.abs(x-y) + + +def l2(x, y): + return torch.pow((x-y), 2) + + +class VQLPIPSWithDiscriminator(nn.Module): + def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0, + disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, + perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, + disc_ndf=64, disc_loss="hinge", n_classes=None, perceptual_loss="lpips", + pixel_loss="l1"): + super().__init__() + assert disc_loss in ["hinge", "vanilla"] + assert perceptual_loss in ["lpips", "clips", "dists"] + assert pixel_loss in ["l1", "l2"] + self.codebook_weight = codebook_weight + self.pixel_weight = pixelloss_weight + if perceptual_loss == "lpips": + print(f"{self.__class__.__name__}: Running with LPIPS.") + self.perceptual_loss = LPIPS().eval() + else: + raise ValueError(f"Unknown perceptual loss: >> {perceptual_loss} <<") + self.perceptual_weight = perceptual_weight + + if pixel_loss == "l1": + self.pixel_loss = l1 + else: + self.pixel_loss = l2 + + self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, + n_layers=disc_num_layers, + use_actnorm=use_actnorm, + ndf=disc_ndf + ).apply(weights_init) + self.discriminator_iter_start = disc_start + if disc_loss == "hinge": + self.disc_loss = hinge_d_loss + elif disc_loss == "vanilla": + self.disc_loss = vanilla_d_loss + else: + raise ValueError(f"Unknown GAN loss '{disc_loss}'.") + print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.") + self.disc_factor = disc_factor + self.discriminator_weight = disc_weight + self.disc_conditional = disc_conditional + self.n_classes = n_classes + + def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): + if last_layer is not None: + nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] + g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] + else: + nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] + g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] + + d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) + d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() + d_weight = d_weight * self.discriminator_weight + return d_weight + + def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx, + global_step, last_layer=None, cond=None, split="train", predicted_indices=None): + if not exists(codebook_loss): + codebook_loss = torch.tensor([0.]).to(inputs.device) + #rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) + rec_loss = self.pixel_loss(inputs.contiguous(), reconstructions.contiguous()) + if self.perceptual_weight > 0: + p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) + rec_loss = rec_loss + self.perceptual_weight * p_loss + else: + p_loss = torch.tensor([0.0]) + + nll_loss = rec_loss + #nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] + nll_loss = torch.mean(nll_loss) + + # now the GAN part + if optimizer_idx == 0: + # generator update + if cond is None: + assert not self.disc_conditional + logits_fake = self.discriminator(reconstructions.contiguous()) + else: + assert self.disc_conditional + logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1)) + g_loss = -torch.mean(logits_fake) + + try: + d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) + except RuntimeError: + assert not self.training + d_weight = torch.tensor(0.0) + + disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) + loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean() + + log = {"{}/total_loss".format(split): loss.clone().detach().mean(), + "{}/quant_loss".format(split): codebook_loss.detach().mean(), + "{}/nll_loss".format(split): nll_loss.detach().mean(), + "{}/rec_loss".format(split): rec_loss.detach().mean(), + "{}/p_loss".format(split): p_loss.detach().mean(), + "{}/d_weight".format(split): d_weight.detach(), + "{}/disc_factor".format(split): torch.tensor(disc_factor), + "{}/g_loss".format(split): g_loss.detach().mean(), + } + if predicted_indices is not None: + assert self.n_classes is not None + with torch.no_grad(): + perplexity, cluster_usage = measure_perplexity(predicted_indices, self.n_classes) + log[f"{split}/perplexity"] = perplexity + log[f"{split}/cluster_usage"] = cluster_usage + return loss, log + + if optimizer_idx == 1: + # second pass for discriminator update + if cond is None: + logits_real = self.discriminator(inputs.contiguous().detach()) + logits_fake = self.discriminator(reconstructions.contiguous().detach()) + else: + logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1)) + logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1)) + + disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) + d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) + + log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), + "{}/logits_real".format(split): logits_real.detach().mean(), + "{}/logits_fake".format(split): logits_fake.detach().mean() + } + return d_loss, log diff --git a/deforum-stable-diffusion/src/ldm/modules/x_transformer.py b/deforum-stable-diffusion/src/ldm/modules/x_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..5fc15bf9cfe0111a910e7de33d04ffdec3877576 --- /dev/null +++ b/deforum-stable-diffusion/src/ldm/modules/x_transformer.py @@ -0,0 +1,641 @@ +"""shout-out to https://github.com/lucidrains/x-transformers/tree/main/x_transformers""" +import torch +from torch import nn, einsum +import torch.nn.functional as F +from functools import partial +from inspect import isfunction +from collections import namedtuple +from einops import rearrange, repeat, reduce + +# constants + +DEFAULT_DIM_HEAD = 64 + +Intermediates = namedtuple('Intermediates', [ + 'pre_softmax_attn', + 'post_softmax_attn' +]) + +LayerIntermediates = namedtuple('Intermediates', [ + 'hiddens', + 'attn_intermediates' +]) + + +class AbsolutePositionalEmbedding(nn.Module): + def __init__(self, dim, max_seq_len): + super().__init__() + self.emb = nn.Embedding(max_seq_len, dim) + self.init_() + + def init_(self): + nn.init.normal_(self.emb.weight, std=0.02) + + def forward(self, x): + n = torch.arange(x.shape[1], device=x.device) + return self.emb(n)[None, :, :] + + +class FixedPositionalEmbedding(nn.Module): + def __init__(self, dim): + super().__init__() + inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim)) + self.register_buffer('inv_freq', inv_freq) + + def forward(self, x, seq_dim=1, offset=0): + t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq) + offset + sinusoid_inp = torch.einsum('i , j -> i j', t, self.inv_freq) + emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1) + return emb[None, :, :] + + +# helpers + +def exists(val): + return val is not None + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +def always(val): + def inner(*args, **kwargs): + return val + return inner + + +def not_equals(val): + def inner(x): + return x != val + return inner + + +def equals(val): + def inner(x): + return x == val + return inner + + +def max_neg_value(tensor): + return -torch.finfo(tensor.dtype).max + + +# keyword argument helpers + +def pick_and_pop(keys, d): + values = list(map(lambda key: d.pop(key), keys)) + return dict(zip(keys, values)) + + +def group_dict_by_key(cond, d): + return_val = [dict(), dict()] + for key in d.keys(): + match = bool(cond(key)) + ind = int(not match) + return_val[ind][key] = d[key] + return (*return_val,) + + +def string_begins_with(prefix, str): + return str.startswith(prefix) + + +def group_by_key_prefix(prefix, d): + return group_dict_by_key(partial(string_begins_with, prefix), d) + + +def groupby_prefix_and_trim(prefix, d): + kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d) + kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items()))) + return kwargs_without_prefix, kwargs + + +# classes +class Scale(nn.Module): + def __init__(self, value, fn): + super().__init__() + self.value = value + self.fn = fn + + def forward(self, x, **kwargs): + x, *rest = self.fn(x, **kwargs) + return (x * self.value, *rest) + + +class Rezero(nn.Module): + def __init__(self, fn): + super().__init__() + self.fn = fn + self.g = nn.Parameter(torch.zeros(1)) + + def forward(self, x, **kwargs): + x, *rest = self.fn(x, **kwargs) + return (x * self.g, *rest) + + +class ScaleNorm(nn.Module): + def __init__(self, dim, eps=1e-5): + super().__init__() + self.scale = dim ** -0.5 + self.eps = eps + self.g = nn.Parameter(torch.ones(1)) + + def forward(self, x): + norm = torch.norm(x, dim=-1, keepdim=True) * self.scale + return x / norm.clamp(min=self.eps) * self.g + + +class RMSNorm(nn.Module): + def __init__(self, dim, eps=1e-8): + super().__init__() + self.scale = dim ** -0.5 + self.eps = eps + self.g = nn.Parameter(torch.ones(dim)) + + def forward(self, x): + norm = torch.norm(x, dim=-1, keepdim=True) * self.scale + return x / norm.clamp(min=self.eps) * self.g + + +class Residual(nn.Module): + def forward(self, x, residual): + return x + residual + + +class GRUGating(nn.Module): + def __init__(self, dim): + super().__init__() + self.gru = nn.GRUCell(dim, dim) + + def forward(self, x, residual): + gated_output = self.gru( + rearrange(x, 'b n d -> (b n) d'), + rearrange(residual, 'b n d -> (b n) d') + ) + + return gated_output.reshape_as(x) + + +# feedforward + +class GEGLU(nn.Module): + def __init__(self, dim_in, dim_out): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2) + + def forward(self, x): + x, gate = self.proj(x).chunk(2, dim=-1) + return x * F.gelu(gate) + + +class FeedForward(nn.Module): + def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): + super().__init__() + inner_dim = int(dim * mult) + dim_out = default(dim_out, dim) + project_in = nn.Sequential( + nn.Linear(dim, inner_dim), + nn.GELU() + ) if not glu else GEGLU(dim, inner_dim) + + self.net = nn.Sequential( + project_in, + nn.Dropout(dropout), + nn.Linear(inner_dim, dim_out) + ) + + def forward(self, x): + return self.net(x) + + +# attention. +class Attention(nn.Module): + def __init__( + self, + dim, + dim_head=DEFAULT_DIM_HEAD, + heads=8, + causal=False, + mask=None, + talking_heads=False, + sparse_topk=None, + use_entmax15=False, + num_mem_kv=0, + dropout=0., + on_attn=False + ): + super().__init__() + if use_entmax15: + raise NotImplementedError("Check out entmax activation instead of softmax activation!") + self.scale = dim_head ** -0.5 + self.heads = heads + self.causal = causal + self.mask = mask + + inner_dim = dim_head * heads + + self.to_q = nn.Linear(dim, inner_dim, bias=False) + self.to_k = nn.Linear(dim, inner_dim, bias=False) + self.to_v = nn.Linear(dim, inner_dim, bias=False) + self.dropout = nn.Dropout(dropout) + + # talking heads + self.talking_heads = talking_heads + if talking_heads: + self.pre_softmax_proj = nn.Parameter(torch.randn(heads, heads)) + self.post_softmax_proj = nn.Parameter(torch.randn(heads, heads)) + + # explicit topk sparse attention + self.sparse_topk = sparse_topk + + # entmax + #self.attn_fn = entmax15 if use_entmax15 else F.softmax + self.attn_fn = F.softmax + + # add memory key / values + self.num_mem_kv = num_mem_kv + if num_mem_kv > 0: + self.mem_k = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head)) + self.mem_v = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head)) + + # attention on attention + self.attn_on_attn = on_attn + self.to_out = nn.Sequential(nn.Linear(inner_dim, dim * 2), nn.GLU()) if on_attn else nn.Linear(inner_dim, dim) + + def forward( + self, + x, + context=None, + mask=None, + context_mask=None, + rel_pos=None, + sinusoidal_emb=None, + prev_attn=None, + mem=None + ): + b, n, _, h, talking_heads, device = *x.shape, self.heads, self.talking_heads, x.device + kv_input = default(context, x) + + q_input = x + k_input = kv_input + v_input = kv_input + + if exists(mem): + k_input = torch.cat((mem, k_input), dim=-2) + v_input = torch.cat((mem, v_input), dim=-2) + + if exists(sinusoidal_emb): + # in shortformer, the query would start at a position offset depending on the past cached memory + offset = k_input.shape[-2] - q_input.shape[-2] + q_input = q_input + sinusoidal_emb(q_input, offset=offset) + k_input = k_input + sinusoidal_emb(k_input) + + q = self.to_q(q_input) + k = self.to_k(k_input) + v = self.to_v(v_input) + + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v)) + + input_mask = None + if any(map(exists, (mask, context_mask))): + q_mask = default(mask, lambda: torch.ones((b, n), device=device).bool()) + k_mask = q_mask if not exists(context) else context_mask + k_mask = default(k_mask, lambda: torch.ones((b, k.shape[-2]), device=device).bool()) + q_mask = rearrange(q_mask, 'b i -> b () i ()') + k_mask = rearrange(k_mask, 'b j -> b () () j') + input_mask = q_mask * k_mask + + if self.num_mem_kv > 0: + mem_k, mem_v = map(lambda t: repeat(t, 'h n d -> b h n d', b=b), (self.mem_k, self.mem_v)) + k = torch.cat((mem_k, k), dim=-2) + v = torch.cat((mem_v, v), dim=-2) + if exists(input_mask): + input_mask = F.pad(input_mask, (self.num_mem_kv, 0), value=True) + + dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale + mask_value = max_neg_value(dots) + + if exists(prev_attn): + dots = dots + prev_attn + + pre_softmax_attn = dots + + if talking_heads: + dots = einsum('b h i j, h k -> b k i j', dots, self.pre_softmax_proj).contiguous() + + if exists(rel_pos): + dots = rel_pos(dots) + + if exists(input_mask): + dots.masked_fill_(~input_mask, mask_value) + del input_mask + + if self.causal: + i, j = dots.shape[-2:] + r = torch.arange(i, device=device) + mask = rearrange(r, 'i -> () () i ()') < rearrange(r, 'j -> () () () j') + mask = F.pad(mask, (j - i, 0), value=False) + dots.masked_fill_(mask, mask_value) + del mask + + if exists(self.sparse_topk) and self.sparse_topk < dots.shape[-1]: + top, _ = dots.topk(self.sparse_topk, dim=-1) + vk = top[..., -1].unsqueeze(-1).expand_as(dots) + mask = dots < vk + dots.masked_fill_(mask, mask_value) + del mask + + attn = self.attn_fn(dots, dim=-1) + post_softmax_attn = attn + + attn = self.dropout(attn) + + if talking_heads: + attn = einsum('b h i j, h k -> b k i j', attn, self.post_softmax_proj).contiguous() + + out = einsum('b h i j, b h j d -> b h i d', attn, v) + out = rearrange(out, 'b h n d -> b n (h d)') + + intermediates = Intermediates( + pre_softmax_attn=pre_softmax_attn, + post_softmax_attn=post_softmax_attn + ) + + return self.to_out(out), intermediates + + +class AttentionLayers(nn.Module): + def __init__( + self, + dim, + depth, + heads=8, + causal=False, + cross_attend=False, + only_cross=False, + use_scalenorm=False, + use_rmsnorm=False, + use_rezero=False, + rel_pos_num_buckets=32, + rel_pos_max_distance=128, + position_infused_attn=False, + custom_layers=None, + sandwich_coef=None, + par_ratio=None, + residual_attn=False, + cross_residual_attn=False, + macaron=False, + pre_norm=True, + gate_residual=False, + **kwargs + ): + super().__init__() + ff_kwargs, kwargs = groupby_prefix_and_trim('ff_', kwargs) + attn_kwargs, _ = groupby_prefix_and_trim('attn_', kwargs) + + dim_head = attn_kwargs.get('dim_head', DEFAULT_DIM_HEAD) + + self.dim = dim + self.depth = depth + self.layers = nn.ModuleList([]) + + self.has_pos_emb = position_infused_attn + self.pia_pos_emb = FixedPositionalEmbedding(dim) if position_infused_attn else None + self.rotary_pos_emb = always(None) + + assert rel_pos_num_buckets <= rel_pos_max_distance, 'number of relative position buckets must be less than the relative position max distance' + self.rel_pos = None + + self.pre_norm = pre_norm + + self.residual_attn = residual_attn + self.cross_residual_attn = cross_residual_attn + + norm_class = ScaleNorm if use_scalenorm else nn.LayerNorm + norm_class = RMSNorm if use_rmsnorm else norm_class + norm_fn = partial(norm_class, dim) + + norm_fn = nn.Identity if use_rezero else norm_fn + branch_fn = Rezero if use_rezero else None + + if cross_attend and not only_cross: + default_block = ('a', 'c', 'f') + elif cross_attend and only_cross: + default_block = ('c', 'f') + else: + default_block = ('a', 'f') + + if macaron: + default_block = ('f',) + default_block + + if exists(custom_layers): + layer_types = custom_layers + elif exists(par_ratio): + par_depth = depth * len(default_block) + assert 1 < par_ratio <= par_depth, 'par ratio out of range' + default_block = tuple(filter(not_equals('f'), default_block)) + par_attn = par_depth // par_ratio + depth_cut = par_depth * 2 // 3 # 2 / 3 attention layer cutoff suggested by PAR paper + par_width = (depth_cut + depth_cut // par_attn) // par_attn + assert len(default_block) <= par_width, 'default block is too large for par_ratio' + par_block = default_block + ('f',) * (par_width - len(default_block)) + par_head = par_block * par_attn + layer_types = par_head + ('f',) * (par_depth - len(par_head)) + elif exists(sandwich_coef): + assert sandwich_coef > 0 and sandwich_coef <= depth, 'sandwich coefficient should be less than the depth' + layer_types = ('a',) * sandwich_coef + default_block * (depth - sandwich_coef) + ('f',) * sandwich_coef + else: + layer_types = default_block * depth + + self.layer_types = layer_types + self.num_attn_layers = len(list(filter(equals('a'), layer_types))) + + for layer_type in self.layer_types: + if layer_type == 'a': + layer = Attention(dim, heads=heads, causal=causal, **attn_kwargs) + elif layer_type == 'c': + layer = Attention(dim, heads=heads, **attn_kwargs) + elif layer_type == 'f': + layer = FeedForward(dim, **ff_kwargs) + layer = layer if not macaron else Scale(0.5, layer) + else: + raise Exception(f'invalid layer type {layer_type}') + + if isinstance(layer, Attention) and exists(branch_fn): + layer = branch_fn(layer) + + if gate_residual: + residual_fn = GRUGating(dim) + else: + residual_fn = Residual() + + self.layers.append(nn.ModuleList([ + norm_fn(), + layer, + residual_fn + ])) + + def forward( + self, + x, + context=None, + mask=None, + context_mask=None, + mems=None, + return_hiddens=False + ): + hiddens = [] + intermediates = [] + prev_attn = None + prev_cross_attn = None + + mems = mems.copy() if exists(mems) else [None] * self.num_attn_layers + + for ind, (layer_type, (norm, block, residual_fn)) in enumerate(zip(self.layer_types, self.layers)): + is_last = ind == (len(self.layers) - 1) + + if layer_type == 'a': + hiddens.append(x) + layer_mem = mems.pop(0) + + residual = x + + if self.pre_norm: + x = norm(x) + + if layer_type == 'a': + out, inter = block(x, mask=mask, sinusoidal_emb=self.pia_pos_emb, rel_pos=self.rel_pos, + prev_attn=prev_attn, mem=layer_mem) + elif layer_type == 'c': + out, inter = block(x, context=context, mask=mask, context_mask=context_mask, prev_attn=prev_cross_attn) + elif layer_type == 'f': + out = block(x) + + x = residual_fn(out, residual) + + if layer_type in ('a', 'c'): + intermediates.append(inter) + + if layer_type == 'a' and self.residual_attn: + prev_attn = inter.pre_softmax_attn + elif layer_type == 'c' and self.cross_residual_attn: + prev_cross_attn = inter.pre_softmax_attn + + if not self.pre_norm and not is_last: + x = norm(x) + + if return_hiddens: + intermediates = LayerIntermediates( + hiddens=hiddens, + attn_intermediates=intermediates + ) + + return x, intermediates + + return x + + +class Encoder(AttentionLayers): + def __init__(self, **kwargs): + assert 'causal' not in kwargs, 'cannot set causality on encoder' + super().__init__(causal=False, **kwargs) + + + +class TransformerWrapper(nn.Module): + def __init__( + self, + *, + num_tokens, + max_seq_len, + attn_layers, + emb_dim=None, + max_mem_len=0., + emb_dropout=0., + num_memory_tokens=None, + tie_embedding=False, + use_pos_emb=True + ): + super().__init__() + assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder' + + dim = attn_layers.dim + emb_dim = default(emb_dim, dim) + + self.max_seq_len = max_seq_len + self.max_mem_len = max_mem_len + self.num_tokens = num_tokens + + self.token_emb = nn.Embedding(num_tokens, emb_dim) + self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len) if ( + use_pos_emb and not attn_layers.has_pos_emb) else always(0) + self.emb_dropout = nn.Dropout(emb_dropout) + + self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity() + self.attn_layers = attn_layers + self.norm = nn.LayerNorm(dim) + + self.init_() + + self.to_logits = nn.Linear(dim, num_tokens) if not tie_embedding else lambda t: t @ self.token_emb.weight.t() + + # memory tokens (like [cls]) from Memory Transformers paper + num_memory_tokens = default(num_memory_tokens, 0) + self.num_memory_tokens = num_memory_tokens + if num_memory_tokens > 0: + self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim)) + + # let funnel encoder know number of memory tokens, if specified + if hasattr(attn_layers, 'num_memory_tokens'): + attn_layers.num_memory_tokens = num_memory_tokens + + def init_(self): + nn.init.normal_(self.token_emb.weight, std=0.02) + + def forward( + self, + x, + return_embeddings=False, + mask=None, + return_mems=False, + return_attn=False, + mems=None, + **kwargs + ): + b, n, device, num_mem = *x.shape, x.device, self.num_memory_tokens + x = self.token_emb(x) + x += self.pos_emb(x) + x = self.emb_dropout(x) + + x = self.project_emb(x) + + if num_mem > 0: + mem = repeat(self.memory_tokens, 'n d -> b n d', b=b) + x = torch.cat((mem, x), dim=1) + + # auto-handle masking after appending memory tokens + if exists(mask): + mask = F.pad(mask, (num_mem, 0), value=True) + + x, intermediates = self.attn_layers(x, mask=mask, mems=mems, return_hiddens=True, **kwargs) + x = self.norm(x) + + mem, x = x[:, :num_mem], x[:, num_mem:] + + out = self.to_logits(x) if not return_embeddings else x + + if return_mems: + hiddens = intermediates.hiddens + new_mems = list(map(lambda pair: torch.cat(pair, dim=-2), zip(mems, hiddens))) if exists(mems) else hiddens + new_mems = list(map(lambda t: t[..., -self.max_mem_len:, :].detach(), new_mems)) + return out, new_mems + + if return_attn: + attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates)) + return out, attn_maps + + return out + diff --git a/deforum-stable-diffusion/src/ldm/util.py b/deforum-stable-diffusion/src/ldm/util.py new file mode 100644 index 0000000000000000000000000000000000000000..8ba38853e7a07228cc2c187742b5c45d7359b3f9 --- /dev/null +++ b/deforum-stable-diffusion/src/ldm/util.py @@ -0,0 +1,203 @@ +import importlib + +import torch +import numpy as np +from collections import abc +from einops import rearrange +from functools import partial + +import multiprocessing as mp +from threading import Thread +from queue import Queue + +from inspect import isfunction +from PIL import Image, ImageDraw, ImageFont + + +def log_txt_as_img(wh, xc, size=10): + # wh a tuple of (width, height) + # xc a list of captions to plot + b = len(xc) + txts = list() + for bi in range(b): + txt = Image.new("RGB", wh, color="white") + draw = ImageDraw.Draw(txt) + font = ImageFont.truetype('data/DejaVuSans.ttf', size=size) + nc = int(40 * (wh[0] / 256)) + lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc)) + + try: + draw.text((0, 0), lines, fill="black", font=font) + except UnicodeEncodeError: + print("Cant encode string for logging. Skipping.") + + txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0 + txts.append(txt) + txts = np.stack(txts) + txts = torch.tensor(txts) + return txts + + +def ismap(x): + if not isinstance(x, torch.Tensor): + return False + return (len(x.shape) == 4) and (x.shape[1] > 3) + + +def isimage(x): + if not isinstance(x, torch.Tensor): + return False + return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) + + +def exists(x): + return x is not None + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +def mean_flat(tensor): + """ + https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86 + Take the mean over all non-batch dimensions. + """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +def count_params(model, verbose=False): + total_params = sum(p.numel() for p in model.parameters()) + if verbose: + print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.") + return total_params + + +def instantiate_from_config(config): + if not "target" in config: + if config == '__is_first_stage__': + return None + elif config == "__is_unconditional__": + return None + raise KeyError("Expected key `target` to instantiate.") + return get_obj_from_str(config["target"])(**config.get("params", dict())) + + +def get_obj_from_str(string, reload=False): + module, cls = string.rsplit(".", 1) + if reload: + module_imp = importlib.import_module(module) + importlib.reload(module_imp) + return getattr(importlib.import_module(module, package=None), cls) + + +def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False): + # create dummy dataset instance + + # run prefetching + if idx_to_fn: + res = func(data, worker_id=idx) + else: + res = func(data) + Q.put([idx, res]) + Q.put("Done") + + +def parallel_data_prefetch( + func: callable, data, n_proc, target_data_type="ndarray", cpu_intensive=True, use_worker_id=False +): + # if target_data_type not in ["ndarray", "list"]: + # raise ValueError( + # "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray." + # ) + if isinstance(data, np.ndarray) and target_data_type == "list": + raise ValueError("list expected but function got ndarray.") + elif isinstance(data, abc.Iterable): + if isinstance(data, dict): + print( + f'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.' + ) + data = list(data.values()) + if target_data_type == "ndarray": + data = np.asarray(data) + else: + data = list(data) + else: + raise TypeError( + f"The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}." + ) + + if cpu_intensive: + Q = mp.Queue(1000) + proc = mp.Process + else: + Q = Queue(1000) + proc = Thread + # spawn processes + if target_data_type == "ndarray": + arguments = [ + [func, Q, part, i, use_worker_id] + for i, part in enumerate(np.array_split(data, n_proc)) + ] + else: + step = ( + int(len(data) / n_proc + 1) + if len(data) % n_proc != 0 + else int(len(data) / n_proc) + ) + arguments = [ + [func, Q, part, i, use_worker_id] + for i, part in enumerate( + [data[i: i + step] for i in range(0, len(data), step)] + ) + ] + processes = [] + for i in range(n_proc): + p = proc(target=_do_parallel_data_prefetch, args=arguments[i]) + processes += [p] + + # start processes + print(f"Start prefetching...") + import time + + start = time.time() + gather_res = [[] for _ in range(n_proc)] + try: + for p in processes: + p.start() + + k = 0 + while k < n_proc: + # get result + res = Q.get() + if res == "Done": + k += 1 + else: + gather_res[res[0]] = res[1] + + except Exception as e: + print("Exception: ", e) + for p in processes: + p.terminate() + + raise e + finally: + for p in processes: + p.join() + print(f"Prefetching complete. [{time.time() - start} sec.]") + + if target_data_type == 'ndarray': + if not isinstance(gather_res[0], np.ndarray): + return np.concatenate([np.asarray(r) for r in gather_res], axis=0) + + # order outputs + return np.concatenate(gather_res, axis=0) + elif target_data_type == 'list': + out = [] + for r in gather_res: + out.extend(r) + return out + else: + return gather_res diff --git a/deforum-stable-diffusion/src/midas/base_model.py b/deforum-stable-diffusion/src/midas/base_model.py new file mode 100644 index 0000000000000000000000000000000000000000..5cf430239b47ec5ec07531263f26f5c24a2311cd --- /dev/null +++ b/deforum-stable-diffusion/src/midas/base_model.py @@ -0,0 +1,16 @@ +import torch + + +class BaseModel(torch.nn.Module): + def load(self, path): + """Load model from file. + + Args: + path (str): file path + """ + parameters = torch.load(path, map_location=torch.device('cpu')) + + if "optimizer" in parameters: + parameters = parameters["model"] + + self.load_state_dict(parameters) diff --git a/deforum-stable-diffusion/src/midas/blocks.py b/deforum-stable-diffusion/src/midas/blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..2145d18fa98060a618536d9a64fe6589e9be4f78 --- /dev/null +++ b/deforum-stable-diffusion/src/midas/blocks.py @@ -0,0 +1,342 @@ +import torch +import torch.nn as nn + +from .vit import ( + _make_pretrained_vitb_rn50_384, + _make_pretrained_vitl16_384, + _make_pretrained_vitb16_384, + forward_vit, +) + +def _make_encoder(backbone, features, use_pretrained, groups=1, expand=False, exportable=True, hooks=None, use_vit_only=False, use_readout="ignore",): + if backbone == "vitl16_384": + pretrained = _make_pretrained_vitl16_384( + use_pretrained, hooks=hooks, use_readout=use_readout + ) + scratch = _make_scratch( + [256, 512, 1024, 1024], features, groups=groups, expand=expand + ) # ViT-L/16 - 85.0% Top1 (backbone) + elif backbone == "vitb_rn50_384": + pretrained = _make_pretrained_vitb_rn50_384( + use_pretrained, + hooks=hooks, + use_vit_only=use_vit_only, + use_readout=use_readout, + ) + scratch = _make_scratch( + [256, 512, 768, 768], features, groups=groups, expand=expand + ) # ViT-H/16 - 85.0% Top1 (backbone) + elif backbone == "vitb16_384": + pretrained = _make_pretrained_vitb16_384( + use_pretrained, hooks=hooks, use_readout=use_readout + ) + scratch = _make_scratch( + [96, 192, 384, 768], features, groups=groups, expand=expand + ) # ViT-B/16 - 84.6% Top1 (backbone) + elif backbone == "resnext101_wsl": + pretrained = _make_pretrained_resnext101_wsl(use_pretrained) + scratch = _make_scratch([256, 512, 1024, 2048], features, groups=groups, expand=expand) # efficientnet_lite3 + elif backbone == "efficientnet_lite3": + pretrained = _make_pretrained_efficientnet_lite3(use_pretrained, exportable=exportable) + scratch = _make_scratch([32, 48, 136, 384], features, groups=groups, expand=expand) # efficientnet_lite3 + else: + print(f"Backbone '{backbone}' not implemented") + assert False + + return pretrained, scratch + + +def _make_scratch(in_shape, out_shape, groups=1, expand=False): + scratch = nn.Module() + + out_shape1 = out_shape + out_shape2 = out_shape + out_shape3 = out_shape + out_shape4 = out_shape + if expand==True: + out_shape1 = out_shape + out_shape2 = out_shape*2 + out_shape3 = out_shape*4 + out_shape4 = out_shape*8 + + scratch.layer1_rn = nn.Conv2d( + in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + scratch.layer2_rn = nn.Conv2d( + in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + scratch.layer3_rn = nn.Conv2d( + in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + scratch.layer4_rn = nn.Conv2d( + in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + + return scratch + + +def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False): + efficientnet = torch.hub.load( + "rwightman/gen-efficientnet-pytorch", + "tf_efficientnet_lite3", + pretrained=use_pretrained, + exportable=exportable + ) + return _make_efficientnet_backbone(efficientnet) + + +def _make_efficientnet_backbone(effnet): + pretrained = nn.Module() + + pretrained.layer1 = nn.Sequential( + effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[0:2] + ) + pretrained.layer2 = nn.Sequential(*effnet.blocks[2:3]) + pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5]) + pretrained.layer4 = nn.Sequential(*effnet.blocks[5:9]) + + return pretrained + + +def _make_resnet_backbone(resnet): + pretrained = nn.Module() + pretrained.layer1 = nn.Sequential( + resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1 + ) + + pretrained.layer2 = resnet.layer2 + pretrained.layer3 = resnet.layer3 + pretrained.layer4 = resnet.layer4 + + return pretrained + + +def _make_pretrained_resnext101_wsl(use_pretrained): + resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl") + return _make_resnet_backbone(resnet) + + + +class Interpolate(nn.Module): + """Interpolation module. + """ + + def __init__(self, scale_factor, mode, align_corners=False): + """Init. + + Args: + scale_factor (float): scaling + mode (str): interpolation mode + """ + super(Interpolate, self).__init__() + + self.interp = nn.functional.interpolate + self.scale_factor = scale_factor + self.mode = mode + self.align_corners = align_corners + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input + + Returns: + tensor: interpolated data + """ + + x = self.interp( + x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners + ) + + return x + + +class ResidualConvUnit(nn.Module): + """Residual convolution module. + """ + + def __init__(self, features): + """Init. + + Args: + features (int): number of features + """ + super().__init__() + + self.conv1 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True + ) + + self.conv2 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True + ) + + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input + + Returns: + tensor: output + """ + out = self.relu(x) + out = self.conv1(out) + out = self.relu(out) + out = self.conv2(out) + + return out + x + + +class FeatureFusionBlock(nn.Module): + """Feature fusion block. + """ + + def __init__(self, features): + """Init. + + Args: + features (int): number of features + """ + super(FeatureFusionBlock, self).__init__() + + self.resConfUnit1 = ResidualConvUnit(features) + self.resConfUnit2 = ResidualConvUnit(features) + + def forward(self, *xs): + """Forward pass. + + Returns: + tensor: output + """ + output = xs[0] + + if len(xs) == 2: + output += self.resConfUnit1(xs[1]) + + output = self.resConfUnit2(output) + + output = nn.functional.interpolate( + output, scale_factor=2, mode="bilinear", align_corners=True + ) + + return output + + + + +class ResidualConvUnit_custom(nn.Module): + """Residual convolution module. + """ + + def __init__(self, features, activation, bn): + """Init. + + Args: + features (int): number of features + """ + super().__init__() + + self.bn = bn + + self.groups=1 + + self.conv1 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups + ) + + self.conv2 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups + ) + + if self.bn==True: + self.bn1 = nn.BatchNorm2d(features) + self.bn2 = nn.BatchNorm2d(features) + + self.activation = activation + + self.skip_add = nn.quantized.FloatFunctional() + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input + + Returns: + tensor: output + """ + + out = self.activation(x) + out = self.conv1(out) + if self.bn==True: + out = self.bn1(out) + + out = self.activation(out) + out = self.conv2(out) + if self.bn==True: + out = self.bn2(out) + + if self.groups > 1: + out = self.conv_merge(out) + + return self.skip_add.add(out, x) + + # return out + x + + +class FeatureFusionBlock_custom(nn.Module): + """Feature fusion block. + """ + + def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True): + """Init. + + Args: + features (int): number of features + """ + super(FeatureFusionBlock_custom, self).__init__() + + self.deconv = deconv + self.align_corners = align_corners + + self.groups=1 + + self.expand = expand + out_features = features + if self.expand==True: + out_features = features//2 + + self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1) + + self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn) + self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn) + + self.skip_add = nn.quantized.FloatFunctional() + + def forward(self, *xs): + """Forward pass. + + Returns: + tensor: output + """ + output = xs[0] + + if len(xs) == 2: + res = self.resConfUnit1(xs[1]) + output = self.skip_add.add(output, res) + # output += res + + output = self.resConfUnit2(output) + + output = nn.functional.interpolate( + output, scale_factor=2, mode="bilinear", align_corners=self.align_corners + ) + + output = self.out_conv(output) + + return output + diff --git a/deforum-stable-diffusion/src/midas/dpt_depth.py b/deforum-stable-diffusion/src/midas/dpt_depth.py new file mode 100644 index 0000000000000000000000000000000000000000..4e9aab5d2767dffea39da5b3f30e2798688216f1 --- /dev/null +++ b/deforum-stable-diffusion/src/midas/dpt_depth.py @@ -0,0 +1,109 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .base_model import BaseModel +from .blocks import ( + FeatureFusionBlock, + FeatureFusionBlock_custom, + Interpolate, + _make_encoder, + forward_vit, +) + + +def _make_fusion_block(features, use_bn): + return FeatureFusionBlock_custom( + features, + nn.ReLU(False), + deconv=False, + bn=use_bn, + expand=False, + align_corners=True, + ) + + +class DPT(BaseModel): + def __init__( + self, + head, + features=256, + backbone="vitb_rn50_384", + readout="project", + channels_last=False, + use_bn=False, + ): + + super(DPT, self).__init__() + + self.channels_last = channels_last + + hooks = { + "vitb_rn50_384": [0, 1, 8, 11], + "vitb16_384": [2, 5, 8, 11], + "vitl16_384": [5, 11, 17, 23], + } + + # Instantiate backbone and reassemble blocks + self.pretrained, self.scratch = _make_encoder( + backbone, + features, + False, # Set to true of you want to train from scratch, uses ImageNet weights + groups=1, + expand=False, + exportable=False, + hooks=hooks[backbone], + use_readout=readout, + ) + + self.scratch.refinenet1 = _make_fusion_block(features, use_bn) + self.scratch.refinenet2 = _make_fusion_block(features, use_bn) + self.scratch.refinenet3 = _make_fusion_block(features, use_bn) + self.scratch.refinenet4 = _make_fusion_block(features, use_bn) + + self.scratch.output_conv = head + + + def forward(self, x): + if self.channels_last == True: + x.contiguous(memory_format=torch.channels_last) + + layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x) + + layer_1_rn = self.scratch.layer1_rn(layer_1) + layer_2_rn = self.scratch.layer2_rn(layer_2) + layer_3_rn = self.scratch.layer3_rn(layer_3) + layer_4_rn = self.scratch.layer4_rn(layer_4) + + path_4 = self.scratch.refinenet4(layer_4_rn) + path_3 = self.scratch.refinenet3(path_4, layer_3_rn) + path_2 = self.scratch.refinenet2(path_3, layer_2_rn) + path_1 = self.scratch.refinenet1(path_2, layer_1_rn) + + out = self.scratch.output_conv(path_1) + + return out + + +class DPTDepthModel(DPT): + def __init__(self, path=None, non_negative=True, **kwargs): + features = kwargs["features"] if "features" in kwargs else 256 + + head = nn.Sequential( + nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1), + Interpolate(scale_factor=2, mode="bilinear", align_corners=True), + nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1), + nn.ReLU(True), + nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), + nn.ReLU(True) if non_negative else nn.Identity(), + nn.Identity(), + ) + + super().__init__(head, **kwargs) + + if path is not None: + self.load(path) + + def forward(self, x): + return super().forward(x).squeeze(dim=1) + diff --git a/deforum-stable-diffusion/src/midas/midas_net.py b/deforum-stable-diffusion/src/midas/midas_net.py new file mode 100644 index 0000000000000000000000000000000000000000..8a954977800b0a0f48807e80fa63041910e33c1f --- /dev/null +++ b/deforum-stable-diffusion/src/midas/midas_net.py @@ -0,0 +1,76 @@ +"""MidashNet: Network for monocular depth estimation trained by mixing several datasets. +This file contains code that is adapted from +https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py +""" +import torch +import torch.nn as nn + +from .base_model import BaseModel +from .blocks import FeatureFusionBlock, Interpolate, _make_encoder + + +class MidasNet(BaseModel): + """Network for monocular depth estimation. + """ + + def __init__(self, path=None, features=256, non_negative=True): + """Init. + + Args: + path (str, optional): Path to saved model. Defaults to None. + features (int, optional): Number of features. Defaults to 256. + backbone (str, optional): Backbone network for encoder. Defaults to resnet50 + """ + print("Loading weights: ", path) + + super(MidasNet, self).__init__() + + use_pretrained = False if path is None else True + + self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained) + + self.scratch.refinenet4 = FeatureFusionBlock(features) + self.scratch.refinenet3 = FeatureFusionBlock(features) + self.scratch.refinenet2 = FeatureFusionBlock(features) + self.scratch.refinenet1 = FeatureFusionBlock(features) + + self.scratch.output_conv = nn.Sequential( + nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1), + Interpolate(scale_factor=2, mode="bilinear"), + nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1), + nn.ReLU(True), + nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), + nn.ReLU(True) if non_negative else nn.Identity(), + ) + + if path: + self.load(path) + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input data (image) + + Returns: + tensor: depth + """ + + layer_1 = self.pretrained.layer1(x) + layer_2 = self.pretrained.layer2(layer_1) + layer_3 = self.pretrained.layer3(layer_2) + layer_4 = self.pretrained.layer4(layer_3) + + layer_1_rn = self.scratch.layer1_rn(layer_1) + layer_2_rn = self.scratch.layer2_rn(layer_2) + layer_3_rn = self.scratch.layer3_rn(layer_3) + layer_4_rn = self.scratch.layer4_rn(layer_4) + + path_4 = self.scratch.refinenet4(layer_4_rn) + path_3 = self.scratch.refinenet3(path_4, layer_3_rn) + path_2 = self.scratch.refinenet2(path_3, layer_2_rn) + path_1 = self.scratch.refinenet1(path_2, layer_1_rn) + + out = self.scratch.output_conv(path_1) + + return torch.squeeze(out, dim=1) diff --git a/deforum-stable-diffusion/src/midas/midas_net_custom.py b/deforum-stable-diffusion/src/midas/midas_net_custom.py new file mode 100644 index 0000000000000000000000000000000000000000..50e4acb5e53d5fabefe3dde16ab49c33c2b7797c --- /dev/null +++ b/deforum-stable-diffusion/src/midas/midas_net_custom.py @@ -0,0 +1,128 @@ +"""MidashNet: Network for monocular depth estimation trained by mixing several datasets. +This file contains code that is adapted from +https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py +""" +import torch +import torch.nn as nn + +from .base_model import BaseModel +from .blocks import FeatureFusionBlock, FeatureFusionBlock_custom, Interpolate, _make_encoder + + +class MidasNet_small(BaseModel): + """Network for monocular depth estimation. + """ + + def __init__(self, path=None, features=64, backbone="efficientnet_lite3", non_negative=True, exportable=True, channels_last=False, align_corners=True, + blocks={'expand': True}): + """Init. + + Args: + path (str, optional): Path to saved model. Defaults to None. + features (int, optional): Number of features. Defaults to 256. + backbone (str, optional): Backbone network for encoder. Defaults to resnet50 + """ + print("Loading weights: ", path) + + super(MidasNet_small, self).__init__() + + use_pretrained = False if path else True + + self.channels_last = channels_last + self.blocks = blocks + self.backbone = backbone + + self.groups = 1 + + features1=features + features2=features + features3=features + features4=features + self.expand = False + if "expand" in self.blocks and self.blocks['expand'] == True: + self.expand = True + features1=features + features2=features*2 + features3=features*4 + features4=features*8 + + self.pretrained, self.scratch = _make_encoder(self.backbone, features, use_pretrained, groups=self.groups, expand=self.expand, exportable=exportable) + + self.scratch.activation = nn.ReLU(False) + + self.scratch.refinenet4 = FeatureFusionBlock_custom(features4, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) + self.scratch.refinenet3 = FeatureFusionBlock_custom(features3, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) + self.scratch.refinenet2 = FeatureFusionBlock_custom(features2, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) + self.scratch.refinenet1 = FeatureFusionBlock_custom(features1, self.scratch.activation, deconv=False, bn=False, align_corners=align_corners) + + + self.scratch.output_conv = nn.Sequential( + nn.Conv2d(features, features//2, kernel_size=3, stride=1, padding=1, groups=self.groups), + Interpolate(scale_factor=2, mode="bilinear"), + nn.Conv2d(features//2, 32, kernel_size=3, stride=1, padding=1), + self.scratch.activation, + nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), + nn.ReLU(True) if non_negative else nn.Identity(), + nn.Identity(), + ) + + if path: + self.load(path) + + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input data (image) + + Returns: + tensor: depth + """ + if self.channels_last==True: + print("self.channels_last = ", self.channels_last) + x.contiguous(memory_format=torch.channels_last) + + + layer_1 = self.pretrained.layer1(x) + layer_2 = self.pretrained.layer2(layer_1) + layer_3 = self.pretrained.layer3(layer_2) + layer_4 = self.pretrained.layer4(layer_3) + + layer_1_rn = self.scratch.layer1_rn(layer_1) + layer_2_rn = self.scratch.layer2_rn(layer_2) + layer_3_rn = self.scratch.layer3_rn(layer_3) + layer_4_rn = self.scratch.layer4_rn(layer_4) + + + path_4 = self.scratch.refinenet4(layer_4_rn) + path_3 = self.scratch.refinenet3(path_4, layer_3_rn) + path_2 = self.scratch.refinenet2(path_3, layer_2_rn) + path_1 = self.scratch.refinenet1(path_2, layer_1_rn) + + out = self.scratch.output_conv(path_1) + + return torch.squeeze(out, dim=1) + + + +def fuse_model(m): + prev_previous_type = nn.Identity() + prev_previous_name = '' + previous_type = nn.Identity() + previous_name = '' + for name, module in m.named_modules(): + if prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d and type(module) == nn.ReLU: + # print("FUSED ", prev_previous_name, previous_name, name) + torch.quantization.fuse_modules(m, [prev_previous_name, previous_name, name], inplace=True) + elif prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d: + # print("FUSED ", prev_previous_name, previous_name) + torch.quantization.fuse_modules(m, [prev_previous_name, previous_name], inplace=True) + # elif previous_type == nn.Conv2d and type(module) == nn.ReLU: + # print("FUSED ", previous_name, name) + # torch.quantization.fuse_modules(m, [previous_name, name], inplace=True) + + prev_previous_type = previous_type + prev_previous_name = previous_name + previous_type = type(module) + previous_name = name \ No newline at end of file diff --git a/deforum-stable-diffusion/src/midas/transforms.py b/deforum-stable-diffusion/src/midas/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..350cbc11662633ad7f8968eb10be2e7de6e384e9 --- /dev/null +++ b/deforum-stable-diffusion/src/midas/transforms.py @@ -0,0 +1,234 @@ +import numpy as np +import cv2 +import math + + +def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA): + """Rezise the sample to ensure the given size. Keeps aspect ratio. + + Args: + sample (dict): sample + size (tuple): image size + + Returns: + tuple: new size + """ + shape = list(sample["disparity"].shape) + + if shape[0] >= size[0] and shape[1] >= size[1]: + return sample + + scale = [0, 0] + scale[0] = size[0] / shape[0] + scale[1] = size[1] / shape[1] + + scale = max(scale) + + shape[0] = math.ceil(scale * shape[0]) + shape[1] = math.ceil(scale * shape[1]) + + # resize + sample["image"] = cv2.resize( + sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method + ) + + sample["disparity"] = cv2.resize( + sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST + ) + sample["mask"] = cv2.resize( + sample["mask"].astype(np.float32), + tuple(shape[::-1]), + interpolation=cv2.INTER_NEAREST, + ) + sample["mask"] = sample["mask"].astype(bool) + + return tuple(shape) + + +class Resize(object): + """Resize sample to given size (width, height). + """ + + def __init__( + self, + width, + height, + resize_target=True, + keep_aspect_ratio=False, + ensure_multiple_of=1, + resize_method="lower_bound", + image_interpolation_method=cv2.INTER_AREA, + ): + """Init. + + Args: + width (int): desired output width + height (int): desired output height + resize_target (bool, optional): + True: Resize the full sample (image, mask, target). + False: Resize image only. + Defaults to True. + keep_aspect_ratio (bool, optional): + True: Keep the aspect ratio of the input sample. + Output sample might not have the given width and height, and + resize behaviour depends on the parameter 'resize_method'. + Defaults to False. + ensure_multiple_of (int, optional): + Output width and height is constrained to be multiple of this parameter. + Defaults to 1. + resize_method (str, optional): + "lower_bound": Output will be at least as large as the given size. + "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.) + "minimal": Scale as least as possible. (Output size might be smaller than given size.) + Defaults to "lower_bound". + """ + self.__width = width + self.__height = height + + self.__resize_target = resize_target + self.__keep_aspect_ratio = keep_aspect_ratio + self.__multiple_of = ensure_multiple_of + self.__resize_method = resize_method + self.__image_interpolation_method = image_interpolation_method + + def constrain_to_multiple_of(self, x, min_val=0, max_val=None): + y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int) + + if max_val is not None and y > max_val: + y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int) + + if y < min_val: + y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int) + + return y + + def get_size(self, width, height): + # determine new height and width + scale_height = self.__height / height + scale_width = self.__width / width + + if self.__keep_aspect_ratio: + if self.__resize_method == "lower_bound": + # scale such that output size is lower bound + if scale_width > scale_height: + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + elif self.__resize_method == "upper_bound": + # scale such that output size is upper bound + if scale_width < scale_height: + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + elif self.__resize_method == "minimal": + # scale as least as possbile + if abs(1 - scale_width) < abs(1 - scale_height): + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + else: + raise ValueError( + f"resize_method {self.__resize_method} not implemented" + ) + + if self.__resize_method == "lower_bound": + new_height = self.constrain_to_multiple_of( + scale_height * height, min_val=self.__height + ) + new_width = self.constrain_to_multiple_of( + scale_width * width, min_val=self.__width + ) + elif self.__resize_method == "upper_bound": + new_height = self.constrain_to_multiple_of( + scale_height * height, max_val=self.__height + ) + new_width = self.constrain_to_multiple_of( + scale_width * width, max_val=self.__width + ) + elif self.__resize_method == "minimal": + new_height = self.constrain_to_multiple_of(scale_height * height) + new_width = self.constrain_to_multiple_of(scale_width * width) + else: + raise ValueError(f"resize_method {self.__resize_method} not implemented") + + return (new_width, new_height) + + def __call__(self, sample): + width, height = self.get_size( + sample["image"].shape[1], sample["image"].shape[0] + ) + + # resize sample + sample["image"] = cv2.resize( + sample["image"], + (width, height), + interpolation=self.__image_interpolation_method, + ) + + if self.__resize_target: + if "disparity" in sample: + sample["disparity"] = cv2.resize( + sample["disparity"], + (width, height), + interpolation=cv2.INTER_NEAREST, + ) + + if "depth" in sample: + sample["depth"] = cv2.resize( + sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST + ) + + sample["mask"] = cv2.resize( + sample["mask"].astype(np.float32), + (width, height), + interpolation=cv2.INTER_NEAREST, + ) + sample["mask"] = sample["mask"].astype(bool) + + return sample + + +class NormalizeImage(object): + """Normlize image by given mean and std. + """ + + def __init__(self, mean, std): + self.__mean = mean + self.__std = std + + def __call__(self, sample): + sample["image"] = (sample["image"] - self.__mean) / self.__std + + return sample + + +class PrepareForNet(object): + """Prepare sample for usage as network input. + """ + + def __init__(self): + pass + + def __call__(self, sample): + image = np.transpose(sample["image"], (2, 0, 1)) + sample["image"] = np.ascontiguousarray(image).astype(np.float32) + + if "mask" in sample: + sample["mask"] = sample["mask"].astype(np.float32) + sample["mask"] = np.ascontiguousarray(sample["mask"]) + + if "disparity" in sample: + disparity = sample["disparity"].astype(np.float32) + sample["disparity"] = np.ascontiguousarray(disparity) + + if "depth" in sample: + depth = sample["depth"].astype(np.float32) + sample["depth"] = np.ascontiguousarray(depth) + + return sample diff --git a/deforum-stable-diffusion/src/midas/vit.py b/deforum-stable-diffusion/src/midas/vit.py new file mode 100644 index 0000000000000000000000000000000000000000..ea46b1be88b261b0dec04f3da0256f5f66f88a74 --- /dev/null +++ b/deforum-stable-diffusion/src/midas/vit.py @@ -0,0 +1,491 @@ +import torch +import torch.nn as nn +import timm +import types +import math +import torch.nn.functional as F + + +class Slice(nn.Module): + def __init__(self, start_index=1): + super(Slice, self).__init__() + self.start_index = start_index + + def forward(self, x): + return x[:, self.start_index :] + + +class AddReadout(nn.Module): + def __init__(self, start_index=1): + super(AddReadout, self).__init__() + self.start_index = start_index + + def forward(self, x): + if self.start_index == 2: + readout = (x[:, 0] + x[:, 1]) / 2 + else: + readout = x[:, 0] + return x[:, self.start_index :] + readout.unsqueeze(1) + + +class ProjectReadout(nn.Module): + def __init__(self, in_features, start_index=1): + super(ProjectReadout, self).__init__() + self.start_index = start_index + + self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU()) + + def forward(self, x): + readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index :]) + features = torch.cat((x[:, self.start_index :], readout), -1) + + return self.project(features) + + +class Transpose(nn.Module): + def __init__(self, dim0, dim1): + super(Transpose, self).__init__() + self.dim0 = dim0 + self.dim1 = dim1 + + def forward(self, x): + x = x.transpose(self.dim0, self.dim1) + return x + + +def forward_vit(pretrained, x): + b, c, h, w = x.shape + + glob = pretrained.model.forward_flex(x) + + layer_1 = pretrained.activations["1"] + layer_2 = pretrained.activations["2"] + layer_3 = pretrained.activations["3"] + layer_4 = pretrained.activations["4"] + + layer_1 = pretrained.act_postprocess1[0:2](layer_1) + layer_2 = pretrained.act_postprocess2[0:2](layer_2) + layer_3 = pretrained.act_postprocess3[0:2](layer_3) + layer_4 = pretrained.act_postprocess4[0:2](layer_4) + + unflatten = nn.Sequential( + nn.Unflatten( + 2, + torch.Size( + [ + h // pretrained.model.patch_size[1], + w // pretrained.model.patch_size[0], + ] + ), + ) + ) + + if layer_1.ndim == 3: + layer_1 = unflatten(layer_1) + if layer_2.ndim == 3: + layer_2 = unflatten(layer_2) + if layer_3.ndim == 3: + layer_3 = unflatten(layer_3) + if layer_4.ndim == 3: + layer_4 = unflatten(layer_4) + + layer_1 = pretrained.act_postprocess1[3 : len(pretrained.act_postprocess1)](layer_1) + layer_2 = pretrained.act_postprocess2[3 : len(pretrained.act_postprocess2)](layer_2) + layer_3 = pretrained.act_postprocess3[3 : len(pretrained.act_postprocess3)](layer_3) + layer_4 = pretrained.act_postprocess4[3 : len(pretrained.act_postprocess4)](layer_4) + + return layer_1, layer_2, layer_3, layer_4 + + +def _resize_pos_embed(self, posemb, gs_h, gs_w): + posemb_tok, posemb_grid = ( + posemb[:, : self.start_index], + posemb[0, self.start_index :], + ) + + gs_old = int(math.sqrt(len(posemb_grid))) + + posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2) + posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear") + posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1) + + posemb = torch.cat([posemb_tok, posemb_grid], dim=1) + + return posemb + + +def forward_flex(self, x): + b, c, h, w = x.shape + + pos_embed = self._resize_pos_embed( + self.pos_embed, h // self.patch_size[1], w // self.patch_size[0] + ) + + B = x.shape[0] + + if hasattr(self.patch_embed, "backbone"): + x = self.patch_embed.backbone(x) + if isinstance(x, (list, tuple)): + x = x[-1] # last feature if backbone outputs list/tuple of features + + x = self.patch_embed.proj(x).flatten(2).transpose(1, 2) + + if getattr(self, "dist_token", None) is not None: + cls_tokens = self.cls_token.expand( + B, -1, -1 + ) # stole cls_tokens impl from Phil Wang, thanks + dist_token = self.dist_token.expand(B, -1, -1) + x = torch.cat((cls_tokens, dist_token, x), dim=1) + else: + cls_tokens = self.cls_token.expand( + B, -1, -1 + ) # stole cls_tokens impl from Phil Wang, thanks + x = torch.cat((cls_tokens, x), dim=1) + + x = x + pos_embed + x = self.pos_drop(x) + + for blk in self.blocks: + x = blk(x) + + x = self.norm(x) + + return x + + +activations = {} + + +def get_activation(name): + def hook(model, input, output): + activations[name] = output + + return hook + + +def get_readout_oper(vit_features, features, use_readout, start_index=1): + if use_readout == "ignore": + readout_oper = [Slice(start_index)] * len(features) + elif use_readout == "add": + readout_oper = [AddReadout(start_index)] * len(features) + elif use_readout == "project": + readout_oper = [ + ProjectReadout(vit_features, start_index) for out_feat in features + ] + else: + assert ( + False + ), "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'" + + return readout_oper + + +def _make_vit_b16_backbone( + model, + features=[96, 192, 384, 768], + size=[384, 384], + hooks=[2, 5, 8, 11], + vit_features=768, + use_readout="ignore", + start_index=1, +): + pretrained = nn.Module() + + pretrained.model = model + pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1")) + pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2")) + pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3")) + pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4")) + + pretrained.activations = activations + + readout_oper = get_readout_oper(vit_features, features, use_readout, start_index) + + # 32, 48, 136, 384 + pretrained.act_postprocess1 = nn.Sequential( + readout_oper[0], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[0], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=features[0], + out_channels=features[0], + kernel_size=4, + stride=4, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + + pretrained.act_postprocess2 = nn.Sequential( + readout_oper[1], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[1], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=features[1], + out_channels=features[1], + kernel_size=2, + stride=2, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + + pretrained.act_postprocess3 = nn.Sequential( + readout_oper[2], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[2], + kernel_size=1, + stride=1, + padding=0, + ), + ) + + pretrained.act_postprocess4 = nn.Sequential( + readout_oper[3], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[3], + kernel_size=1, + stride=1, + padding=0, + ), + nn.Conv2d( + in_channels=features[3], + out_channels=features[3], + kernel_size=3, + stride=2, + padding=1, + ), + ) + + pretrained.model.start_index = start_index + pretrained.model.patch_size = [16, 16] + + # We inject this function into the VisionTransformer instances so that + # we can use it with interpolated position embeddings without modifying the library source. + pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model) + pretrained.model._resize_pos_embed = types.MethodType( + _resize_pos_embed, pretrained.model + ) + + return pretrained + + +def _make_pretrained_vitl16_384(pretrained, use_readout="ignore", hooks=None): + model = timm.create_model("vit_large_patch16_384", pretrained=pretrained) + + hooks = [5, 11, 17, 23] if hooks == None else hooks + return _make_vit_b16_backbone( + model, + features=[256, 512, 1024, 1024], + hooks=hooks, + vit_features=1024, + use_readout=use_readout, + ) + + +def _make_pretrained_vitb16_384(pretrained, use_readout="ignore", hooks=None): + model = timm.create_model("vit_base_patch16_384", pretrained=pretrained) + + hooks = [2, 5, 8, 11] if hooks == None else hooks + return _make_vit_b16_backbone( + model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout + ) + + +def _make_pretrained_deitb16_384(pretrained, use_readout="ignore", hooks=None): + model = timm.create_model("vit_deit_base_patch16_384", pretrained=pretrained) + + hooks = [2, 5, 8, 11] if hooks == None else hooks + return _make_vit_b16_backbone( + model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout + ) + + +def _make_pretrained_deitb16_distil_384(pretrained, use_readout="ignore", hooks=None): + model = timm.create_model( + "vit_deit_base_distilled_patch16_384", pretrained=pretrained + ) + + hooks = [2, 5, 8, 11] if hooks == None else hooks + return _make_vit_b16_backbone( + model, + features=[96, 192, 384, 768], + hooks=hooks, + use_readout=use_readout, + start_index=2, + ) + + +def _make_vit_b_rn50_backbone( + model, + features=[256, 512, 768, 768], + size=[384, 384], + hooks=[0, 1, 8, 11], + vit_features=768, + use_vit_only=False, + use_readout="ignore", + start_index=1, +): + pretrained = nn.Module() + + pretrained.model = model + + if use_vit_only == True: + pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1")) + pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2")) + else: + pretrained.model.patch_embed.backbone.stages[0].register_forward_hook( + get_activation("1") + ) + pretrained.model.patch_embed.backbone.stages[1].register_forward_hook( + get_activation("2") + ) + + pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3")) + pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4")) + + pretrained.activations = activations + + readout_oper = get_readout_oper(vit_features, features, use_readout, start_index) + + if use_vit_only == True: + pretrained.act_postprocess1 = nn.Sequential( + readout_oper[0], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[0], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=features[0], + out_channels=features[0], + kernel_size=4, + stride=4, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + + pretrained.act_postprocess2 = nn.Sequential( + readout_oper[1], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[1], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=features[1], + out_channels=features[1], + kernel_size=2, + stride=2, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + else: + pretrained.act_postprocess1 = nn.Sequential( + nn.Identity(), nn.Identity(), nn.Identity() + ) + pretrained.act_postprocess2 = nn.Sequential( + nn.Identity(), nn.Identity(), nn.Identity() + ) + + pretrained.act_postprocess3 = nn.Sequential( + readout_oper[2], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[2], + kernel_size=1, + stride=1, + padding=0, + ), + ) + + pretrained.act_postprocess4 = nn.Sequential( + readout_oper[3], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[3], + kernel_size=1, + stride=1, + padding=0, + ), + nn.Conv2d( + in_channels=features[3], + out_channels=features[3], + kernel_size=3, + stride=2, + padding=1, + ), + ) + + pretrained.model.start_index = start_index + pretrained.model.patch_size = [16, 16] + + # We inject this function into the VisionTransformer instances so that + # we can use it with interpolated position embeddings without modifying the library source. + pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model) + + # We inject this function into the VisionTransformer instances so that + # we can use it with interpolated position embeddings without modifying the library source. + pretrained.model._resize_pos_embed = types.MethodType( + _resize_pos_embed, pretrained.model + ) + + return pretrained + + +def _make_pretrained_vitb_rn50_384( + pretrained, use_readout="ignore", hooks=None, use_vit_only=False +): + model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained) + + hooks = [0, 1, 8, 11] if hooks == None else hooks + return _make_vit_b_rn50_backbone( + model, + features=[256, 512, 768, 768], + size=[384, 384], + hooks=hooks, + use_vit_only=use_vit_only, + use_readout=use_readout, + ) diff --git a/deforum-stable-diffusion/src/model_io.py b/deforum-stable-diffusion/src/model_io.py new file mode 100644 index 0000000000000000000000000000000000000000..bca5c177d753ff4c86671b9e34aa30fc212a76fc --- /dev/null +++ b/deforum-stable-diffusion/src/model_io.py @@ -0,0 +1,72 @@ +import os + +import torch + + +def save_weights(model, filename, path="./saved_models"): + if not os.path.isdir(path): + os.makedirs(path) + + fpath = os.path.join(path, filename) + torch.save(model.state_dict(), fpath) + return + + +def save_checkpoint(model, optimizer, epoch, filename, root="./checkpoints"): + if not os.path.isdir(root): + os.makedirs(root) + + fpath = os.path.join(root, filename) + torch.save( + { + "model": model.state_dict(), + "optimizer": optimizer.state_dict(), + "epoch": epoch + } + , fpath) + + +def load_weights(model, filename, path="./saved_models"): + fpath = os.path.join(path, filename) + state_dict = torch.load(fpath) + model.load_state_dict(state_dict) + return model + + +def load_checkpoint(fpath, model, optimizer=None): + ckpt = torch.load(fpath, map_location='cpu') + if optimizer is None: + optimizer = ckpt.get('optimizer', None) + else: + optimizer.load_state_dict(ckpt['optimizer']) + epoch = ckpt['epoch'] + + if 'model' in ckpt: + ckpt = ckpt['model'] + load_dict = {} + for k, v in ckpt.items(): + if k.startswith('module.'): + k_ = k.replace('module.', '') + load_dict[k_] = v + else: + load_dict[k] = v + + modified = {} # backward compatibility to older naming of architecture blocks + for k, v in load_dict.items(): + if k.startswith('adaptive_bins_layer.embedding_conv.'): + k_ = k.replace('adaptive_bins_layer.embedding_conv.', + 'adaptive_bins_layer.conv3x3.') + modified[k_] = v + # del load_dict[k] + + elif k.startswith('adaptive_bins_layer.patch_transformer.embedding_encoder'): + + k_ = k.replace('adaptive_bins_layer.patch_transformer.embedding_encoder', + 'adaptive_bins_layer.patch_transformer.embedding_convPxP') + modified[k_] = v + # del load_dict[k] + else: + modified[k] = v # else keep the original + + model.load_state_dict(modified) + return model, optimizer, epoch diff --git a/deforum-stable-diffusion/src/py3d_tools.py b/deforum-stable-diffusion/src/py3d_tools.py new file mode 100644 index 0000000000000000000000000000000000000000..f299dd02bfd5568413b1732468746921095c3851 --- /dev/null +++ b/deforum-stable-diffusion/src/py3d_tools.py @@ -0,0 +1,1799 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import sys +import math +import warnings +from typing import List, Optional, Sequence, Tuple, Union, Any + +import numpy as np +import torch +import torch.nn.functional as F + +import copy +import inspect +import torch.nn as nn + +Device = Union[str, torch.device] + +# Default values for rotation and translation matrices. +_R = torch.eye(3)[None] # (1, 3, 3) +_T = torch.zeros(1, 3) # (1, 3) + + +# Provide get_origin and get_args even in Python 3.7. + +if sys.version_info >= (3, 8, 0): + from typing import get_args, get_origin +elif sys.version_info >= (3, 7, 0): + + def get_origin(cls): # pragma: no cover + return getattr(cls, "__origin__", None) + + def get_args(cls): # pragma: no cover + return getattr(cls, "__args__", None) + + +else: + raise ImportError("This module requires Python 3.7+") + +################################################################ +## ██████╗██╗ █████╗ ███████╗███████╗███████╗███████╗ ## +## ██╔════╝██║ ██╔══██╗██╔════╝██╔════╝██╔════╝██╔════╝ ## +## ██║ ██║ ███████║███████╗███████╗█████╗ ███████╗ ## +## ██║ ██║ ██╔══██║╚════██║╚════██║██╔══╝ ╚════██║ ## +## ╚██████╗███████╗██║ ██║███████║███████║███████╗███████║ ## +## ╚═════╝╚══════╝╚═╝ ╚═╝╚══════╝╚══════╝╚══════╝╚══════╝ ## +################################################################ + +class Transform3d: + """ + A Transform3d object encapsulates a batch of N 3D transformations, and knows + how to transform points and normal vectors. Suppose that t is a Transform3d; + then we can do the following: + + .. code-block:: python + + N = len(t) + points = torch.randn(N, P, 3) + normals = torch.randn(N, P, 3) + points_transformed = t.transform_points(points) # => (N, P, 3) + normals_transformed = t.transform_normals(normals) # => (N, P, 3) + + + BROADCASTING + Transform3d objects supports broadcasting. Suppose that t1 and tN are + Transform3d objects with len(t1) == 1 and len(tN) == N respectively. Then we + can broadcast transforms like this: + + .. code-block:: python + + t1.transform_points(torch.randn(P, 3)) # => (P, 3) + t1.transform_points(torch.randn(1, P, 3)) # => (1, P, 3) + t1.transform_points(torch.randn(M, P, 3)) # => (M, P, 3) + tN.transform_points(torch.randn(P, 3)) # => (N, P, 3) + tN.transform_points(torch.randn(1, P, 3)) # => (N, P, 3) + + + COMBINING TRANSFORMS + Transform3d objects can be combined in two ways: composing and stacking. + Composing is function composition. Given Transform3d objects t1, t2, t3, + the following all compute the same thing: + + .. code-block:: python + + y1 = t3.transform_points(t2.transform_points(t1.transform_points(x))) + y2 = t1.compose(t2).compose(t3).transform_points(x) + y3 = t1.compose(t2, t3).transform_points(x) + + + Composing transforms should broadcast. + + .. code-block:: python + + if len(t1) == 1 and len(t2) == N, then len(t1.compose(t2)) == N. + + We can also stack a sequence of Transform3d objects, which represents + composition along the batch dimension; then the following should compute the + same thing. + + .. code-block:: python + + N, M = len(tN), len(tM) + xN = torch.randn(N, P, 3) + xM = torch.randn(M, P, 3) + y1 = torch.cat([tN.transform_points(xN), tM.transform_points(xM)], dim=0) + y2 = tN.stack(tM).transform_points(torch.cat([xN, xM], dim=0)) + + BUILDING TRANSFORMS + We provide convenience methods for easily building Transform3d objects + as compositions of basic transforms. + + .. code-block:: python + + # Scale by 0.5, then translate by (1, 2, 3) + t1 = Transform3d().scale(0.5).translate(1, 2, 3) + + # Scale each axis by a different amount, then translate, then scale + t2 = Transform3d().scale(1, 3, 3).translate(2, 3, 1).scale(2.0) + + t3 = t1.compose(t2) + tN = t1.stack(t3, t3) + + + BACKPROP THROUGH TRANSFORMS + When building transforms, we can also parameterize them by Torch tensors; + in this case we can backprop through the construction and application of + Transform objects, so they could be learned via gradient descent or + predicted by a neural network. + + .. code-block:: python + + s1_params = torch.randn(N, requires_grad=True) + t_params = torch.randn(N, 3, requires_grad=True) + s2_params = torch.randn(N, 3, requires_grad=True) + + t = Transform3d().scale(s1_params).translate(t_params).scale(s2_params) + x = torch.randn(N, 3) + y = t.transform_points(x) + loss = compute_loss(y) + loss.backward() + + with torch.no_grad(): + s1_params -= lr * s1_params.grad + t_params -= lr * t_params.grad + s2_params -= lr * s2_params.grad + + CONVENTIONS + We adopt a right-hand coordinate system, meaning that rotation about an axis + with a positive angle results in a counter clockwise rotation. + + This class assumes that transformations are applied on inputs which + are row vectors. The internal representation of the Nx4x4 transformation + matrix is of the form: + + .. code-block:: python + + M = [ + [Rxx, Ryx, Rzx, 0], + [Rxy, Ryy, Rzy, 0], + [Rxz, Ryz, Rzz, 0], + [Tx, Ty, Tz, 1], + ] + + To apply the transformation to points which are row vectors, the M matrix + can be pre multiplied by the points: + + .. code-block:: python + + points = [[0, 1, 2]] # (1 x 3) xyz coordinates of a point + transformed_points = points * M + + """ + + def __init__( + self, + dtype: torch.dtype = torch.float32, + device: Device = "cpu", + matrix: Optional[torch.Tensor] = None, + ) -> None: + """ + Args: + dtype: The data type of the transformation matrix. + to be used if `matrix = None`. + device: The device for storing the implemented transformation. + If `matrix != None`, uses the device of input `matrix`. + matrix: A tensor of shape (4, 4) or of shape (minibatch, 4, 4) + representing the 4x4 3D transformation matrix. + If `None`, initializes with identity using + the specified `device` and `dtype`. + """ + + if matrix is None: + self._matrix = torch.eye(4, dtype=dtype, device=device).view(1, 4, 4) + else: + if matrix.ndim not in (2, 3): + raise ValueError('"matrix" has to be a 2- or a 3-dimensional tensor.') + if matrix.shape[-2] != 4 or matrix.shape[-1] != 4: + raise ValueError( + '"matrix" has to be a tensor of shape (minibatch, 4, 4)' + ) + # set dtype and device from matrix + dtype = matrix.dtype + device = matrix.device + self._matrix = matrix.view(-1, 4, 4) + + self._transforms = [] # store transforms to compose + self._lu = None + self.device = make_device(device) + self.dtype = dtype + + def __len__(self) -> int: + return self.get_matrix().shape[0] + + def __getitem__( + self, index: Union[int, List[int], slice, torch.Tensor] + ) -> "Transform3d": + """ + Args: + index: Specifying the index of the transform to retrieve. + Can be an int, slice, list of ints, boolean, long tensor. + Supports negative indices. + + Returns: + Transform3d object with selected transforms. The tensors are not cloned. + """ + if isinstance(index, int): + index = [index] + return self.__class__(matrix=self.get_matrix()[index]) + + def compose(self, *others: "Transform3d") -> "Transform3d": + """ + Return a new Transform3d representing the composition of self with the + given other transforms, which will be stored as an internal list. + + Args: + *others: Any number of Transform3d objects + + Returns: + A new Transform3d with the stored transforms + """ + out = Transform3d(dtype=self.dtype, device=self.device) + out._matrix = self._matrix.clone() + for other in others: + if not isinstance(other, Transform3d): + msg = "Only possible to compose Transform3d objects; got %s" + raise ValueError(msg % type(other)) + out._transforms = self._transforms + list(others) + return out + + def get_matrix(self) -> torch.Tensor: + """ + Return a matrix which is the result of composing this transform + with others stored in self.transforms. Where necessary transforms + are broadcast against each other. + For example, if self.transforms contains transforms t1, t2, and t3, and + given a set of points x, the following should be true: + + .. code-block:: python + + y1 = t1.compose(t2, t3).transform(x) + y2 = t3.transform(t2.transform(t1.transform(x))) + y1.get_matrix() == y2.get_matrix() + + Returns: + A transformation matrix representing the composed inputs. + """ + composed_matrix = self._matrix.clone() + if len(self._transforms) > 0: + for other in self._transforms: + other_matrix = other.get_matrix() + composed_matrix = _broadcast_bmm(composed_matrix, other_matrix) + return composed_matrix + + def _get_matrix_inverse(self) -> torch.Tensor: + """ + Return the inverse of self._matrix. + """ + return torch.inverse(self._matrix) + + def inverse(self, invert_composed: bool = False) -> "Transform3d": + """ + Returns a new Transform3d object that represents an inverse of the + current transformation. + + Args: + invert_composed: + - True: First compose the list of stored transformations + and then apply inverse to the result. This is + potentially slower for classes of transformations + with inverses that can be computed efficiently + (e.g. rotations and translations). + - False: Invert the individual stored transformations + independently without composing them. + + Returns: + A new Transform3d object containing the inverse of the original + transformation. + """ + + tinv = Transform3d(dtype=self.dtype, device=self.device) + + if invert_composed: + # first compose then invert + tinv._matrix = torch.inverse(self.get_matrix()) + else: + # self._get_matrix_inverse() implements efficient inverse + # of self._matrix + i_matrix = self._get_matrix_inverse() + + # 2 cases: + if len(self._transforms) > 0: + # a) Either we have a non-empty list of transforms: + # Here we take self._matrix and append its inverse at the + # end of the reverted _transforms list. After composing + # the transformations with get_matrix(), this correctly + # right-multiplies by the inverse of self._matrix + # at the end of the composition. + tinv._transforms = [t.inverse() for t in reversed(self._transforms)] + last = Transform3d(dtype=self.dtype, device=self.device) + last._matrix = i_matrix + tinv._transforms.append(last) + else: + # b) Or there are no stored transformations + # we just set inverted matrix + tinv._matrix = i_matrix + + return tinv + + def stack(self, *others: "Transform3d") -> "Transform3d": + """ + Return a new batched Transform3d representing the batch elements from + self and all the given other transforms all batched together. + + Args: + *others: Any number of Transform3d objects + + Returns: + A new Transform3d. + """ + transforms = [self] + list(others) + matrix = torch.cat([t.get_matrix() for t in transforms], dim=0) + out = Transform3d(dtype=self.dtype, device=self.device) + out._matrix = matrix + return out + + def transform_points(self, points, eps: Optional[float] = None) -> torch.Tensor: + """ + Use this transform to transform a set of 3D points. Assumes row major + ordering of the input points. + + Args: + points: Tensor of shape (P, 3) or (N, P, 3) + eps: If eps!=None, the argument is used to clamp the + last coordinate before performing the final division. + The clamping corresponds to: + last_coord := (last_coord.sign() + (last_coord==0)) * + torch.clamp(last_coord.abs(), eps), + i.e. the last coordinates that are exactly 0 will + be clamped to +eps. + + Returns: + points_out: points of shape (N, P, 3) or (P, 3) depending + on the dimensions of the transform + """ + points_batch = points.clone() + if points_batch.dim() == 2: + points_batch = points_batch[None] # (P, 3) -> (1, P, 3) + if points_batch.dim() != 3: + msg = "Expected points to have dim = 2 or dim = 3: got shape %r" + raise ValueError(msg % repr(points.shape)) + + N, P, _3 = points_batch.shape + ones = torch.ones(N, P, 1, dtype=points.dtype, device=points.device) + points_batch = torch.cat([points_batch, ones], dim=2) + + composed_matrix = self.get_matrix() + points_out = _broadcast_bmm(points_batch, composed_matrix) + denom = points_out[..., 3:] # denominator + if eps is not None: + denom_sign = denom.sign() + (denom == 0.0).type_as(denom) + denom = denom_sign * torch.clamp(denom.abs(), eps) + points_out = points_out[..., :3] / denom + + # When transform is (1, 4, 4) and points is (P, 3) return + # points_out of shape (P, 3) + if points_out.shape[0] == 1 and points.dim() == 2: + points_out = points_out.reshape(points.shape) + + return points_out + + def transform_normals(self, normals) -> torch.Tensor: + """ + Use this transform to transform a set of normal vectors. + + Args: + normals: Tensor of shape (P, 3) or (N, P, 3) + + Returns: + normals_out: Tensor of shape (P, 3) or (N, P, 3) depending + on the dimensions of the transform + """ + if normals.dim() not in [2, 3]: + msg = "Expected normals to have dim = 2 or dim = 3: got shape %r" + raise ValueError(msg % (normals.shape,)) + composed_matrix = self.get_matrix() + + # TODO: inverse is bad! Solve a linear system instead + mat = composed_matrix[:, :3, :3] + normals_out = _broadcast_bmm(normals, mat.transpose(1, 2).inverse()) + + # This doesn't pass unit tests. TODO investigate further + # if self._lu is None: + # self._lu = self._matrix[:, :3, :3].transpose(1, 2).lu() + # normals_out = normals.lu_solve(*self._lu) + + # When transform is (1, 4, 4) and normals is (P, 3) return + # normals_out of shape (P, 3) + if normals_out.shape[0] == 1 and normals.dim() == 2: + normals_out = normals_out.reshape(normals.shape) + + return normals_out + + def translate(self, *args, **kwargs) -> "Transform3d": + return self.compose( + Translate(device=self.device, dtype=self.dtype, *args, **kwargs) + ) + + def scale(self, *args, **kwargs) -> "Transform3d": + return self.compose( + Scale(device=self.device, dtype=self.dtype, *args, **kwargs) + ) + + def rotate(self, *args, **kwargs) -> "Transform3d": + return self.compose( + Rotate(device=self.device, dtype=self.dtype, *args, **kwargs) + ) + + def rotate_axis_angle(self, *args, **kwargs) -> "Transform3d": + return self.compose( + RotateAxisAngle(device=self.device, dtype=self.dtype, *args, **kwargs) + ) + + def clone(self) -> "Transform3d": + """ + Deep copy of Transforms object. All internal tensors are cloned + individually. + + Returns: + new Transforms object. + """ + other = Transform3d(dtype=self.dtype, device=self.device) + if self._lu is not None: + other._lu = [elem.clone() for elem in self._lu] + other._matrix = self._matrix.clone() + other._transforms = [t.clone() for t in self._transforms] + return other + + def to( + self, + device: Device, + copy: bool = False, + dtype: Optional[torch.dtype] = None, + ) -> "Transform3d": + """ + Match functionality of torch.Tensor.to() + If copy = True or the self Tensor is on a different device, the + returned tensor is a copy of self with the desired torch.device. + If copy = False and the self Tensor already has the correct torch.device, + then self is returned. + + Args: + device: Device (as str or torch.device) for the new tensor. + copy: Boolean indicator whether or not to clone self. Default False. + dtype: If not None, casts the internal tensor variables + to a given torch.dtype. + + Returns: + Transform3d object. + """ + device_ = make_device(device) + dtype_ = self.dtype if dtype is None else dtype + skip_to = self.device == device_ and self.dtype == dtype_ + + if not copy and skip_to: + return self + + other = self.clone() + + if skip_to: + return other + + other.device = device_ + other.dtype = dtype_ + other._matrix = other._matrix.to(device=device_, dtype=dtype_) + other._transforms = [ + t.to(device_, copy=copy, dtype=dtype_) for t in other._transforms + ] + return other + + def cpu(self) -> "Transform3d": + return self.to("cpu") + + def cuda(self) -> "Transform3d": + return self.to("cuda") + +class Translate(Transform3d): + def __init__( + self, + x, + y=None, + z=None, + dtype: torch.dtype = torch.float32, + device: Optional[Device] = None, + ) -> None: + """ + Create a new Transform3d representing 3D translations. + + Option I: Translate(xyz, dtype=torch.float32, device='cpu') + xyz should be a tensor of shape (N, 3) + + Option II: Translate(x, y, z, dtype=torch.float32, device='cpu') + Here x, y, and z will be broadcast against each other and + concatenated to form the translation. Each can be: + - A python scalar + - A torch scalar + - A 1D torch tensor + """ + xyz = _handle_input(x, y, z, dtype, device, "Translate") + super().__init__(device=xyz.device, dtype=dtype) + N = xyz.shape[0] + + mat = torch.eye(4, dtype=dtype, device=self.device) + mat = mat.view(1, 4, 4).repeat(N, 1, 1) + mat[:, 3, :3] = xyz + self._matrix = mat + + def _get_matrix_inverse(self) -> torch.Tensor: + """ + Return the inverse of self._matrix. + """ + inv_mask = self._matrix.new_ones([1, 4, 4]) + inv_mask[0, 3, :3] = -1.0 + i_matrix = self._matrix * inv_mask + return i_matrix + +class Rotate(Transform3d): + def __init__( + self, + R: torch.Tensor, + dtype: torch.dtype = torch.float32, + device: Optional[Device] = None, + orthogonal_tol: float = 1e-5, + ) -> None: + """ + Create a new Transform3d representing 3D rotation using a rotation + matrix as the input. + + Args: + R: a tensor of shape (3, 3) or (N, 3, 3) + orthogonal_tol: tolerance for the test of the orthogonality of R + + """ + device_ = get_device(R, device) + super().__init__(device=device_, dtype=dtype) + if R.dim() == 2: + R = R[None] + if R.shape[-2:] != (3, 3): + msg = "R must have shape (3, 3) or (N, 3, 3); got %s" + raise ValueError(msg % repr(R.shape)) + R = R.to(device=device_, dtype=dtype) + _check_valid_rotation_matrix(R, tol=orthogonal_tol) + N = R.shape[0] + mat = torch.eye(4, dtype=dtype, device=device_) + mat = mat.view(1, 4, 4).repeat(N, 1, 1) + mat[:, :3, :3] = R + self._matrix = mat + + def _get_matrix_inverse(self) -> torch.Tensor: + """ + Return the inverse of self._matrix. + """ + return self._matrix.permute(0, 2, 1).contiguous() + +class TensorAccessor(nn.Module): + """ + A helper class to be used with the __getitem__ method. This can be used for + getting/setting the values for an attribute of a class at one particular + index. This is useful when the attributes of a class are batched tensors + and one element in the batch needs to be modified. + """ + + def __init__(self, class_object, index: Union[int, slice]) -> None: + """ + Args: + class_object: this should be an instance of a class which has + attributes which are tensors representing a batch of + values. + index: int/slice, an index indicating the position in the batch. + In __setattr__ and __getattr__ only the value of class + attributes at this index will be accessed. + """ + self.__dict__["class_object"] = class_object + self.__dict__["index"] = index + + def __setattr__(self, name: str, value: Any): + """ + Update the attribute given by `name` to the value given by `value` + at the index specified by `self.index`. + Args: + name: str, name of the attribute. + value: value to set the attribute to. + """ + v = getattr(self.class_object, name) + if not torch.is_tensor(v): + msg = "Can only set values on attributes which are tensors; got %r" + raise AttributeError(msg % type(v)) + + # Convert the attribute to a tensor if it is not a tensor. + if not torch.is_tensor(value): + value = torch.tensor( + value, device=v.device, dtype=v.dtype, requires_grad=v.requires_grad + ) + + # Check the shapes match the existing shape and the shape of the index. + if v.dim() > 1 and value.dim() > 1 and value.shape[1:] != v.shape[1:]: + msg = "Expected value to have shape %r; got %r" + raise ValueError(msg % (v.shape, value.shape)) + if ( + v.dim() == 0 + and isinstance(self.index, slice) + and len(value) != len(self.index) + ): + msg = "Expected value to have len %r; got %r" + raise ValueError(msg % (len(self.index), len(value))) + self.class_object.__dict__[name][self.index] = value + + def __getattr__(self, name: str): + """ + Return the value of the attribute given by "name" on self.class_object + at the index specified in self.index. + Args: + name: string of the attribute name + """ + if hasattr(self.class_object, name): + return self.class_object.__dict__[name][self.index] + else: + msg = "Attribute %s not found on %r" + return AttributeError(msg % (name, self.class_object.__name__)) + +BROADCAST_TYPES = (float, int, list, tuple, torch.Tensor, np.ndarray) + +class TensorProperties(nn.Module): + """ + A mix-in class for storing tensors as properties with helper methods. + """ + + def __init__( + self, + dtype: torch.dtype = torch.float32, + device: Device = "cpu", + **kwargs, + ) -> None: + """ + Args: + dtype: data type to set for the inputs + device: Device (as str or torch.device) + kwargs: any number of keyword arguments. Any arguments which are + of type (float/int/list/tuple/tensor/array) are broadcasted and + other keyword arguments are set as attributes. + """ + super().__init__() + self.device = make_device(device) + self._N = 0 + if kwargs is not None: + + # broadcast all inputs which are float/int/list/tuple/tensor/array + # set as attributes anything else e.g. strings, bools + args_to_broadcast = {} + for k, v in kwargs.items(): + if v is None or isinstance(v, (str, bool)): + setattr(self, k, v) + elif isinstance(v, BROADCAST_TYPES): + args_to_broadcast[k] = v + else: + msg = "Arg %s with type %r is not broadcastable" + warnings.warn(msg % (k, type(v))) + + names = args_to_broadcast.keys() + # convert from type dict.values to tuple + values = tuple(v for v in args_to_broadcast.values()) + + if len(values) > 0: + broadcasted_values = convert_to_tensors_and_broadcast( + *values, device=device + ) + + # Set broadcasted values as attributes on self. + for i, n in enumerate(names): + setattr(self, n, broadcasted_values[i]) + if self._N == 0: + self._N = broadcasted_values[i].shape[0] + + def __len__(self) -> int: + return self._N + + def isempty(self) -> bool: + return self._N == 0 + + def __getitem__(self, index: Union[int, slice]) -> TensorAccessor: + """ + Args: + index: an int or slice used to index all the fields. + Returns: + if `index` is an index int/slice return a TensorAccessor class + with getattribute/setattribute methods which return/update the value + at the index in the original class. + """ + if isinstance(index, (int, slice)): + return TensorAccessor(class_object=self, index=index) + + msg = "Expected index of type int or slice; got %r" + raise ValueError(msg % type(index)) + + # pyre-fixme[14]: `to` overrides method defined in `Module` inconsistently. + def to(self, device: Device = "cpu") -> "TensorProperties": + """ + In place operation to move class properties which are tensors to a + specified device. If self has a property "device", update this as well. + """ + device_ = make_device(device) + for k in dir(self): + v = getattr(self, k) + if k == "device": + setattr(self, k, device_) + if torch.is_tensor(v) and v.device != device_: + setattr(self, k, v.to(device_)) + return self + + def cpu(self) -> "TensorProperties": + return self.to("cpu") + + # pyre-fixme[14]: `cuda` overrides method defined in `Module` inconsistently. + def cuda(self, device: Optional[int] = None) -> "TensorProperties": + return self.to(f"cuda:{device}" if device is not None else "cuda") + + def clone(self, other) -> "TensorProperties": + """ + Update the tensor properties of other with the cloned properties of self. + """ + for k in dir(self): + v = getattr(self, k) + if inspect.ismethod(v) or k.startswith("__"): + continue + if torch.is_tensor(v): + v_clone = v.clone() + else: + v_clone = copy.deepcopy(v) + setattr(other, k, v_clone) + return other + + def gather_props(self, batch_idx) -> "TensorProperties": + """ + This is an in place operation to reformat all tensor class attributes + based on a set of given indices using torch.gather. This is useful when + attributes which are batched tensors e.g. shape (N, 3) need to be + multiplied with another tensor which has a different first dimension + e.g. packed vertices of shape (V, 3). + Example + .. code-block:: python + self.specular_color = (N, 3) tensor of specular colors for each mesh + A lighting calculation may use + .. code-block:: python + verts_packed = meshes.verts_packed() # (V, 3) + To multiply these two tensors the batch dimension needs to be the same. + To achieve this we can do + .. code-block:: python + batch_idx = meshes.verts_packed_to_mesh_idx() # (V) + This gives index of the mesh for each vertex in verts_packed. + .. code-block:: python + self.gather_props(batch_idx) + self.specular_color = (V, 3) tensor with the specular color for + each packed vertex. + torch.gather requires the index tensor to have the same shape as the + input tensor so this method takes care of the reshaping of the index + tensor to use with class attributes with arbitrary dimensions. + Args: + batch_idx: shape (B, ...) where `...` represents an arbitrary + number of dimensions + Returns: + self with all properties reshaped. e.g. a property with shape (N, 3) + is transformed to shape (B, 3). + """ + # Iterate through the attributes of the class which are tensors. + for k in dir(self): + v = getattr(self, k) + if torch.is_tensor(v): + if v.shape[0] > 1: + # There are different values for each batch element + # so gather these using the batch_idx. + # First clone the input batch_idx tensor before + # modifying it. + _batch_idx = batch_idx.clone() + idx_dims = _batch_idx.shape + tensor_dims = v.shape + if len(idx_dims) > len(tensor_dims): + msg = "batch_idx cannot have more dimensions than %s. " + msg += "got shape %r and %s has shape %r" + raise ValueError(msg % (k, idx_dims, k, tensor_dims)) + if idx_dims != tensor_dims: + # To use torch.gather the index tensor (_batch_idx) has + # to have the same shape as the input tensor. + new_dims = len(tensor_dims) - len(idx_dims) + new_shape = idx_dims + (1,) * new_dims + expand_dims = (-1,) + tensor_dims[1:] + _batch_idx = _batch_idx.view(*new_shape) + _batch_idx = _batch_idx.expand(*expand_dims) + + v = v.gather(0, _batch_idx) + setattr(self, k, v) + return self + +class CamerasBase(TensorProperties): + """ + `CamerasBase` implements a base class for all cameras. + For cameras, there are four different coordinate systems (or spaces) + - World coordinate system: This is the system the object lives - the world. + - Camera view coordinate system: This is the system that has its origin on the camera + and the and the Z-axis perpendicular to the image plane. + In PyTorch3D, we assume that +X points left, and +Y points up and + +Z points out from the image plane. + The transformation from world --> view happens after applying a rotation (R) + and translation (T) + - NDC coordinate system: This is the normalized coordinate system that confines + in a volume the rendered part of the object or scene. Also known as view volume. + For square images, given the PyTorch3D convention, (+1, +1, znear) + is the top left near corner, and (-1, -1, zfar) is the bottom right far + corner of the volume. + The transformation from view --> NDC happens after applying the camera + projection matrix (P) if defined in NDC space. + For non square images, we scale the points such that smallest side + has range [-1, 1] and the largest side has range [-u, u], with u > 1. + - Screen coordinate system: This is another representation of the view volume with + the XY coordinates defined in image space instead of a normalized space. + A better illustration of the coordinate systems can be found in + pytorch3d/docs/notes/cameras.md. + It defines methods that are common to all camera models: + - `get_camera_center` that returns the optical center of the camera in + world coordinates + - `get_world_to_view_transform` which returns a 3D transform from + world coordinates to the camera view coordinates (R, T) + - `get_full_projection_transform` which composes the projection + transform (P) with the world-to-view transform (R, T) + - `transform_points` which takes a set of input points in world coordinates and + projects to the space the camera is defined in (NDC or screen) + - `get_ndc_camera_transform` which defines the transform from screen/NDC to + PyTorch3D's NDC space + - `transform_points_ndc` which takes a set of points in world coordinates and + projects them to PyTorch3D's NDC space + - `transform_points_screen` which takes a set of points in world coordinates and + projects them to screen space + For each new camera, one should implement the `get_projection_transform` + routine that returns the mapping from camera view coordinates to camera + coordinates (NDC or screen). + Another useful function that is specific to each camera model is + `unproject_points` which sends points from camera coordinates (NDC or screen) + back to camera view or world coordinates depending on the `world_coordinates` + boolean argument of the function. + """ + + # Used in __getitem__ to index the relevant fields + # When creating a new camera, this should be set in the __init__ + _FIELDS: Tuple[str, ...] = () + + # Names of fields which are a constant property of the whole batch, rather + # than themselves a batch of data. + # When joining objects into a batch, they will have to agree. + _SHARED_FIELDS: Tuple[str, ...] = () + + def get_projection_transform(self): + """ + Calculate the projective transformation matrix. + Args: + **kwargs: parameters for the projection can be passed in as keyword + arguments to override the default values set in `__init__`. + Return: + a `Transform3d` object which represents a batch of projection + matrices of shape (N, 3, 3) + """ + raise NotImplementedError() + + def unproject_points(self, xy_depth: torch.Tensor, **kwargs): + """ + Transform input points from camera coodinates (NDC or screen) + to the world / camera coordinates. + Each of the input points `xy_depth` of shape (..., 3) is + a concatenation of the x, y location and its depth. + For instance, for an input 2D tensor of shape `(num_points, 3)` + `xy_depth` takes the following form: + `xy_depth[i] = [x[i], y[i], depth[i]]`, + for a each point at an index `i`. + The following example demonstrates the relationship between + `transform_points` and `unproject_points`: + .. code-block:: python + cameras = # camera object derived from CamerasBase + xyz = # 3D points of shape (batch_size, num_points, 3) + # transform xyz to the camera view coordinates + xyz_cam = cameras.get_world_to_view_transform().transform_points(xyz) + # extract the depth of each point as the 3rd coord of xyz_cam + depth = xyz_cam[:, :, 2:] + # project the points xyz to the camera + xy = cameras.transform_points(xyz)[:, :, :2] + # append depth to xy + xy_depth = torch.cat((xy, depth), dim=2) + # unproject to the world coordinates + xyz_unproj_world = cameras.unproject_points(xy_depth, world_coordinates=True) + print(torch.allclose(xyz, xyz_unproj_world)) # True + # unproject to the camera coordinates + xyz_unproj = cameras.unproject_points(xy_depth, world_coordinates=False) + print(torch.allclose(xyz_cam, xyz_unproj)) # True + Args: + xy_depth: torch tensor of shape (..., 3). + world_coordinates: If `True`, unprojects the points back to world + coordinates using the camera extrinsics `R` and `T`. + `False` ignores `R` and `T` and unprojects to + the camera view coordinates. + from_ndc: If `False` (default), assumes xy part of input is in + NDC space if self.in_ndc(), otherwise in screen space. If + `True`, assumes xy is in NDC space even if the camera + is defined in screen space. + Returns + new_points: unprojected points with the same shape as `xy_depth`. + """ + raise NotImplementedError() + + def get_camera_center(self, **kwargs) -> torch.Tensor: + """ + Return the 3D location of the camera optical center + in the world coordinates. + Args: + **kwargs: parameters for the camera extrinsics can be passed in + as keyword arguments to override the default values + set in __init__. + Setting T here will update the values set in init as this + value may be needed later on in the rendering pipeline e.g. for + lighting calculations. + Returns: + C: a batch of 3D locations of shape (N, 3) denoting + the locations of the center of each camera in the batch. + """ + w2v_trans = self.get_world_to_view_transform(**kwargs) + P = w2v_trans.inverse().get_matrix() + # the camera center is the translation component (the first 3 elements + # of the last row) of the inverted world-to-view + # transform (4x4 RT matrix) + C = P[:, 3, :3] + return C + + def get_world_to_view_transform(self, **kwargs) -> Transform3d: + """ + Return the world-to-view transform. + Args: + **kwargs: parameters for the camera extrinsics can be passed in + as keyword arguments to override the default values + set in __init__. + Setting R and T here will update the values set in init as these + values may be needed later on in the rendering pipeline e.g. for + lighting calculations. + Returns: + A Transform3d object which represents a batch of transforms + of shape (N, 3, 3) + """ + R: torch.Tensor = kwargs.get("R", self.R) + T: torch.Tensor = kwargs.get("T", self.T) + self.R = R # pyre-ignore[16] + self.T = T # pyre-ignore[16] + world_to_view_transform = get_world_to_view_transform(R=R, T=T) + return world_to_view_transform + + def get_full_projection_transform(self, **kwargs) -> Transform3d: + """ + Return the full world-to-camera transform composing the + world-to-view and view-to-camera transforms. + If camera is defined in NDC space, the projected points are in NDC space. + If camera is defined in screen space, the projected points are in screen space. + Args: + **kwargs: parameters for the projection transforms can be passed in + as keyword arguments to override the default values + set in __init__. + Setting R and T here will update the values set in init as these + values may be needed later on in the rendering pipeline e.g. for + lighting calculations. + Returns: + a Transform3d object which represents a batch of transforms + of shape (N, 3, 3) + """ + self.R: torch.Tensor = kwargs.get("R", self.R) # pyre-ignore[16] + self.T: torch.Tensor = kwargs.get("T", self.T) # pyre-ignore[16] + world_to_view_transform = self.get_world_to_view_transform(R=self.R, T=self.T) + view_to_proj_transform = self.get_projection_transform(**kwargs) + return world_to_view_transform.compose(view_to_proj_transform) + + def transform_points( + self, points, eps: Optional[float] = None, **kwargs + ) -> torch.Tensor: + """ + Transform input points from world to camera space with the + projection matrix defined by the camera. + For `CamerasBase.transform_points`, setting `eps > 0` + stabilizes gradients since it leads to avoiding division + by excessively low numbers for points close to the camera plane. + Args: + points: torch tensor of shape (..., 3). + eps: If eps!=None, the argument is used to clamp the + divisor in the homogeneous normalization of the points + transformed to the ndc space. Please see + `transforms.Transform3d.transform_points` for details. + For `CamerasBase.transform_points`, setting `eps > 0` + stabilizes gradients since it leads to avoiding division + by excessively low numbers for points close to the + camera plane. + Returns + new_points: transformed points with the same shape as the input. + """ + world_to_proj_transform = self.get_full_projection_transform(**kwargs) + return world_to_proj_transform.transform_points(points, eps=eps) + + def get_ndc_camera_transform(self, **kwargs) -> Transform3d: + """ + Returns the transform from camera projection space (screen or NDC) to NDC space. + For cameras that can be specified in screen space, this transform + allows points to be converted from screen to NDC space. + The default transform scales the points from [0, W]x[0, H] + to [-1, 1]x[-u, u] or [-u, u]x[-1, 1] where u > 1 is the aspect ratio of the image. + This function should be modified per camera definitions if need be, + e.g. for Perspective/Orthographic cameras we provide a custom implementation. + This transform assumes PyTorch3D coordinate system conventions for + both the NDC space and the input points. + This transform interfaces with the PyTorch3D renderer which assumes + input points to the renderer to be in NDC space. + """ + if self.in_ndc(): + return Transform3d(device=self.device, dtype=torch.float32) + else: + # For custom cameras which can be defined in screen space, + # users might might have to implement the screen to NDC transform based + # on the definition of the camera parameters. + # See PerspectiveCameras/OrthographicCameras for an example. + # We don't flip xy because we assume that world points are in + # PyTorch3D coordinates, and thus conversion from screen to ndc + # is a mere scaling from image to [-1, 1] scale. + image_size = kwargs.get("image_size", self.get_image_size()) + return get_screen_to_ndc_transform( + self, with_xyflip=False, image_size=image_size + ) + + def transform_points_ndc( + self, points, eps: Optional[float] = None, **kwargs + ) -> torch.Tensor: + """ + Transforms points from PyTorch3D world/camera space to NDC space. + Input points follow the PyTorch3D coordinate system conventions: +X left, +Y up. + Output points are in NDC space: +X left, +Y up, origin at image center. + Args: + points: torch tensor of shape (..., 3). + eps: If eps!=None, the argument is used to clamp the + divisor in the homogeneous normalization of the points + transformed to the ndc space. Please see + `transforms.Transform3d.transform_points` for details. + For `CamerasBase.transform_points`, setting `eps > 0` + stabilizes gradients since it leads to avoiding division + by excessively low numbers for points close to the + camera plane. + Returns + new_points: transformed points with the same shape as the input. + """ + world_to_ndc_transform = self.get_full_projection_transform(**kwargs) + if not self.in_ndc(): + to_ndc_transform = self.get_ndc_camera_transform(**kwargs) + world_to_ndc_transform = world_to_ndc_transform.compose(to_ndc_transform) + + return world_to_ndc_transform.transform_points(points, eps=eps) + + def transform_points_screen( + self, points, eps: Optional[float] = None, **kwargs + ) -> torch.Tensor: + """ + Transforms points from PyTorch3D world/camera space to screen space. + Input points follow the PyTorch3D coordinate system conventions: +X left, +Y up. + Output points are in screen space: +X right, +Y down, origin at top left corner. + Args: + points: torch tensor of shape (..., 3). + eps: If eps!=None, the argument is used to clamp the + divisor in the homogeneous normalization of the points + transformed to the ndc space. Please see + `transforms.Transform3d.transform_points` for details. + For `CamerasBase.transform_points`, setting `eps > 0` + stabilizes gradients since it leads to avoiding division + by excessively low numbers for points close to the + camera plane. + Returns + new_points: transformed points with the same shape as the input. + """ + points_ndc = self.transform_points_ndc(points, eps=eps, **kwargs) + image_size = kwargs.get("image_size", self.get_image_size()) + return get_ndc_to_screen_transform( + self, with_xyflip=True, image_size=image_size + ).transform_points(points_ndc, eps=eps) + + def clone(self): + """ + Returns a copy of `self`. + """ + cam_type = type(self) + other = cam_type(device=self.device) + return super().clone(other) + + def is_perspective(self): + raise NotImplementedError() + + def in_ndc(self): + """ + Specifies whether the camera is defined in NDC space + or in screen (image) space + """ + raise NotImplementedError() + + def get_znear(self): + return self.znear if hasattr(self, "znear") else None + + def get_image_size(self): + """ + Returns the image size, if provided, expected in the form of (height, width) + The image size is used for conversion of projected points to screen coordinates. + """ + return self.image_size if hasattr(self, "image_size") else None + + def __getitem__( + self, index: Union[int, List[int], torch.LongTensor] + ) -> "CamerasBase": + """ + Override for the __getitem__ method in TensorProperties which needs to be + refactored. + Args: + index: an int/list/long tensor used to index all the fields in the cameras given by + self._FIELDS. + Returns: + if `index` is an index int/list/long tensor return an instance of the current + cameras class with only the values at the selected index. + """ + + kwargs = {} + + if not isinstance(index, (int, list, torch.LongTensor, torch.cuda.LongTensor)): + msg = "Invalid index type, expected int, List[int] or torch.LongTensor; got %r" + raise ValueError(msg % type(index)) + + if isinstance(index, int): + index = [index] + + if max(index) >= len(self): + raise ValueError(f"Index {max(index)} is out of bounds for select cameras") + + for field in self._FIELDS: + val = getattr(self, field, None) + if val is None: + continue + + # e.g. "in_ndc" is set as attribute "_in_ndc" on the class + # but provided as "in_ndc" on initialization + if field.startswith("_"): + field = field[1:] + + if isinstance(val, (str, bool)): + kwargs[field] = val + elif isinstance(val, torch.Tensor): + # In the init, all inputs will be converted to + # tensors before setting as attributes + kwargs[field] = val[index] + else: + raise ValueError(f"Field {field} type is not supported for indexing") + + kwargs["device"] = self.device + return self.__class__(**kwargs) + +class FoVPerspectiveCameras(CamerasBase): + """ + A class which stores a batch of parameters to generate a batch of + projection matrices by specifying the field of view. + The definition of the parameters follow the OpenGL perspective camera. + + The extrinsics of the camera (R and T matrices) can also be set in the + initializer or passed in to `get_full_projection_transform` to get + the full transformation from world -> ndc. + + The `transform_points` method calculates the full world -> ndc transform + and then applies it to the input points. + + The transforms can also be returned separately as Transform3d objects. + + * Setting the Aspect Ratio for Non Square Images * + + If the desired output image size is non square (i.e. a tuple of (H, W) where H != W) + the aspect ratio needs special consideration: There are two aspect ratios + to be aware of: + - the aspect ratio of each pixel + - the aspect ratio of the output image + The `aspect_ratio` setting in the FoVPerspectiveCameras sets the + pixel aspect ratio. When using this camera with the differentiable rasterizer + be aware that in the rasterizer we assume square pixels, but allow + variable image aspect ratio (i.e rectangle images). + + In most cases you will want to set the camera `aspect_ratio=1.0` + (i.e. square pixels) and only vary the output image dimensions in pixels + for rasterization. + """ + + # For __getitem__ + _FIELDS = ( + "K", + "znear", + "zfar", + "aspect_ratio", + "fov", + "R", + "T", + "degrees", + ) + + _SHARED_FIELDS = ("degrees",) + + def __init__( + self, + znear=1.0, + zfar=100.0, + aspect_ratio=1.0, + fov=60.0, + degrees: bool = True, + R: torch.Tensor = _R, + T: torch.Tensor = _T, + K: Optional[torch.Tensor] = None, + device: Device = "cpu", + ) -> None: + """ + + Args: + znear: near clipping plane of the view frustrum. + zfar: far clipping plane of the view frustrum. + aspect_ratio: aspect ratio of the image pixels. + 1.0 indicates square pixels. + fov: field of view angle of the camera. + degrees: bool, set to True if fov is specified in degrees. + R: Rotation matrix of shape (N, 3, 3) + T: Translation matrix of shape (N, 3) + K: (optional) A calibration matrix of shape (N, 4, 4) + If provided, don't need znear, zfar, fov, aspect_ratio, degrees + device: Device (as str or torch.device) + """ + # The initializer formats all inputs to torch tensors and broadcasts + # all the inputs to have the same batch dimension where necessary. + super().__init__( + device=device, + znear=znear, + zfar=zfar, + aspect_ratio=aspect_ratio, + fov=fov, + R=R, + T=T, + K=K, + ) + + # No need to convert to tensor or broadcast. + self.degrees = degrees + + def compute_projection_matrix( + self, znear, zfar, fov, aspect_ratio, degrees: bool + ) -> torch.Tensor: + """ + Compute the calibration matrix K of shape (N, 4, 4) + + Args: + znear: near clipping plane of the view frustrum. + zfar: far clipping plane of the view frustrum. + fov: field of view angle of the camera. + aspect_ratio: aspect ratio of the image pixels. + 1.0 indicates square pixels. + degrees: bool, set to True if fov is specified in degrees. + + Returns: + torch.FloatTensor of the calibration matrix with shape (N, 4, 4) + """ + K = torch.zeros((self._N, 4, 4), device=self.device, dtype=torch.float32) + ones = torch.ones((self._N), dtype=torch.float32, device=self.device) + if degrees: + fov = (np.pi / 180) * fov + + if not torch.is_tensor(fov): + fov = torch.tensor(fov, device=self.device) + tanHalfFov = torch.tan((fov / 2)) + max_y = tanHalfFov * znear + min_y = -max_y + max_x = max_y * aspect_ratio + min_x = -max_x + + # NOTE: In OpenGL the projection matrix changes the handedness of the + # coordinate frame. i.e the NDC space positive z direction is the + # camera space negative z direction. This is because the sign of the z + # in the projection matrix is set to -1.0. + # In pytorch3d we maintain a right handed coordinate system throughout + # so the so the z sign is 1.0. + z_sign = 1.0 + + K[:, 0, 0] = 2.0 * znear / (max_x - min_x) + K[:, 1, 1] = 2.0 * znear / (max_y - min_y) + K[:, 0, 2] = (max_x + min_x) / (max_x - min_x) + K[:, 1, 2] = (max_y + min_y) / (max_y - min_y) + K[:, 3, 2] = z_sign * ones + + # NOTE: This maps the z coordinate from [0, 1] where z = 0 if the point + # is at the near clipping plane and z = 1 when the point is at the far + # clipping plane. + K[:, 2, 2] = z_sign * zfar / (zfar - znear) + K[:, 2, 3] = -(zfar * znear) / (zfar - znear) + + return K + + def get_projection_transform(self, **kwargs) -> Transform3d: + """ + Calculate the perspective projection matrix with a symmetric + viewing frustrum. Use column major order. + The viewing frustrum will be projected into ndc, s.t. + (max_x, max_y) -> (+1, +1) + (min_x, min_y) -> (-1, -1) + + Args: + **kwargs: parameters for the projection can be passed in as keyword + arguments to override the default values set in `__init__`. + + Return: + a Transform3d object which represents a batch of projection + matrices of shape (N, 4, 4) + + .. code-block:: python + + h1 = (max_y + min_y)/(max_y - min_y) + w1 = (max_x + min_x)/(max_x - min_x) + tanhalffov = tan((fov/2)) + s1 = 1/tanhalffov + s2 = 1/(tanhalffov * (aspect_ratio)) + + # To map z to the range [0, 1] use: + f1 = far / (far - near) + f2 = -(far * near) / (far - near) + + # Projection matrix + K = [ + [s1, 0, w1, 0], + [0, s2, h1, 0], + [0, 0, f1, f2], + [0, 0, 1, 0], + ] + """ + K = kwargs.get("K", self.K) + if K is not None: + if K.shape != (self._N, 4, 4): + msg = "Expected K to have shape of (%r, 4, 4)" + raise ValueError(msg % (self._N)) + else: + K = self.compute_projection_matrix( + kwargs.get("znear", self.znear), + kwargs.get("zfar", self.zfar), + kwargs.get("fov", self.fov), + kwargs.get("aspect_ratio", self.aspect_ratio), + kwargs.get("degrees", self.degrees), + ) + + # Transpose the projection matrix as PyTorch3D transforms use row vectors. + transform = Transform3d( + matrix=K.transpose(1, 2).contiguous(), device=self.device + ) + return transform + + def unproject_points( + self, + xy_depth: torch.Tensor, + world_coordinates: bool = True, + scaled_depth_input: bool = False, + **kwargs, + ) -> torch.Tensor: + """>! + FoV cameras further allow for passing depth in world units + (`scaled_depth_input=False`) or in the [0, 1]-normalized units + (`scaled_depth_input=True`) + + Args: + scaled_depth_input: If `True`, assumes the input depth is in + the [0, 1]-normalized units. If `False` the input depth is in + the world units. + """ + + # obtain the relevant transformation to ndc + if world_coordinates: + to_ndc_transform = self.get_full_projection_transform() + else: + to_ndc_transform = self.get_projection_transform() + + if scaled_depth_input: + # the input is scaled depth, so we don't have to do anything + xy_sdepth = xy_depth + else: + # parse out important values from the projection matrix + K_matrix = self.get_projection_transform(**kwargs.copy()).get_matrix() + # parse out f1, f2 from K_matrix + unsqueeze_shape = [1] * xy_depth.dim() + unsqueeze_shape[0] = K_matrix.shape[0] + f1 = K_matrix[:, 2, 2].reshape(unsqueeze_shape) + f2 = K_matrix[:, 3, 2].reshape(unsqueeze_shape) + # get the scaled depth + sdepth = (f1 * xy_depth[..., 2:3] + f2) / xy_depth[..., 2:3] + # concatenate xy + scaled depth + xy_sdepth = torch.cat((xy_depth[..., 0:2], sdepth), dim=-1) + + # unproject with inverse of the projection + unprojection_transform = to_ndc_transform.inverse() + return unprojection_transform.transform_points(xy_sdepth) + + def is_perspective(self): + return True + + def in_ndc(self): + return True + +####################################################################################### +## ██████╗ ███████╗███████╗██╗███╗ ██╗██╗████████╗██╗ ██████╗ ███╗ ██╗███████╗ ## +## ██╔══██╗██╔════╝██╔════╝██║████╗ ██║██║╚══██╔══╝██║██╔═══██╗████╗ ██║██╔════╝ ## +## ██║ ██║█████╗ █████╗ ██║██╔██╗ ██║██║ ██║ ██║██║ ██║██╔██╗ ██║███████╗ ## +## ██║ ██║██╔══╝ ██╔══╝ ██║██║╚██╗██║██║ ██║ ██║██║ ██║██║╚██╗██║╚════██║ ## +## ██████╔╝███████╗██║ ██║██║ ╚████║██║ ██║ ██║╚██████╔╝██║ ╚████║███████║ ## +## ╚═════╝ ╚══════╝╚═╝ ╚═╝╚═╝ ╚═══╝╚═╝ ╚═╝ ╚═╝ ╚═════╝ ╚═╝ ╚═══╝╚══════╝ ## +####################################################################################### + +def make_device(device: Device) -> torch.device: + """ + Makes an actual torch.device object from the device specified as + either a string or torch.device object. If the device is `cuda` without + a specific index, the index of the current device is assigned. + Args: + device: Device (as str or torch.device) + Returns: + A matching torch.device object + """ + device = torch.device(device) if isinstance(device, str) else device + if device.type == "cuda" and device.index is None: # pyre-ignore[16] + # If cuda but with no index, then the current cuda device is indicated. + # In that case, we fix to that device + device = torch.device(f"cuda:{torch.cuda.current_device()}") + return device + +def get_device(x, device: Optional[Device] = None) -> torch.device: + """ + Gets the device of the specified variable x if it is a tensor, or + falls back to a default CPU device otherwise. Allows overriding by + providing an explicit device. + Args: + x: a torch.Tensor to get the device from or another type + device: Device (as str or torch.device) to fall back to + Returns: + A matching torch.device object + """ + + # User overrides device + if device is not None: + return make_device(device) + + # Set device based on input tensor + if torch.is_tensor(x): + return x.device + + # Default device is cpu + return torch.device("cpu") + +def _axis_angle_rotation(axis: str, angle: torch.Tensor) -> torch.Tensor: + """ + Return the rotation matrices for one of the rotations about an axis + of which Euler angles describe, for each value of the angle given. + + Args: + axis: Axis label "X" or "Y or "Z". + angle: any shape tensor of Euler angles in radians + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + """ + + cos = torch.cos(angle) + sin = torch.sin(angle) + one = torch.ones_like(angle) + zero = torch.zeros_like(angle) + + if axis == "X": + R_flat = (one, zero, zero, zero, cos, -sin, zero, sin, cos) + elif axis == "Y": + R_flat = (cos, zero, sin, zero, one, zero, -sin, zero, cos) + elif axis == "Z": + R_flat = (cos, -sin, zero, sin, cos, zero, zero, zero, one) + else: + raise ValueError("letter must be either X, Y or Z.") + + return torch.stack(R_flat, -1).reshape(angle.shape + (3, 3)) + +def euler_angles_to_matrix(euler_angles: torch.Tensor, convention: str) -> torch.Tensor: + """ + Convert rotations given as Euler angles in radians to rotation matrices. + + Args: + euler_angles: Euler angles in radians as tensor of shape (..., 3). + convention: Convention string of three uppercase letters from + {"X", "Y", and "Z"}. + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + """ + if euler_angles.dim() == 0 or euler_angles.shape[-1] != 3: + raise ValueError("Invalid input euler angles.") + if len(convention) != 3: + raise ValueError("Convention must have 3 letters.") + if convention[1] in (convention[0], convention[2]): + raise ValueError(f"Invalid convention {convention}.") + for letter in convention: + if letter not in ("X", "Y", "Z"): + raise ValueError(f"Invalid letter {letter} in convention string.") + matrices = [ + _axis_angle_rotation(c, e) + for c, e in zip(convention, torch.unbind(euler_angles, -1)) + ] + # return functools.reduce(torch.matmul, matrices) + return torch.matmul(torch.matmul(matrices[0], matrices[1]), matrices[2]) + +def _broadcast_bmm(a, b) -> torch.Tensor: + """ + Batch multiply two matrices and broadcast if necessary. + + Args: + a: torch tensor of shape (P, K) or (M, P, K) + b: torch tensor of shape (N, K, K) + + Returns: + a and b broadcast multiplied. The output batch dimension is max(N, M). + + To broadcast transforms across a batch dimension if M != N then + expect that either M = 1 or N = 1. The tensor with batch dimension 1 is + expanded to have shape N or M. + """ + if a.dim() == 2: + a = a[None] + if len(a) != len(b): + if not ((len(a) == 1) or (len(b) == 1)): + msg = "Expected batch dim for bmm to be equal or 1; got %r, %r" + raise ValueError(msg % (a.shape, b.shape)) + if len(a) == 1: + a = a.expand(len(b), -1, -1) + if len(b) == 1: + b = b.expand(len(a), -1, -1) + return a.bmm(b) + +def _safe_det_3x3(t: torch.Tensor): + """ + Fast determinant calculation for a batch of 3x3 matrices. + Note, result of this function might not be the same as `torch.det()`. + The differences might be in the last significant digit. + Args: + t: Tensor of shape (N, 3, 3). + Returns: + Tensor of shape (N) with determinants. + """ + + det = ( + t[..., 0, 0] * (t[..., 1, 1] * t[..., 2, 2] - t[..., 1, 2] * t[..., 2, 1]) + - t[..., 0, 1] * (t[..., 1, 0] * t[..., 2, 2] - t[..., 2, 0] * t[..., 1, 2]) + + t[..., 0, 2] * (t[..., 1, 0] * t[..., 2, 1] - t[..., 2, 0] * t[..., 1, 1]) + ) + + return det + +def get_world_to_view_transform( + R: torch.Tensor = _R, T: torch.Tensor = _T +) -> Transform3d: + """ + This function returns a Transform3d representing the transformation + matrix to go from world space to view space by applying a rotation and + a translation. + PyTorch3D uses the same convention as Hartley & Zisserman. + I.e., for camera extrinsic parameters R (rotation) and T (translation), + we map a 3D point `X_world` in world coordinates to + a point `X_cam` in camera coordinates with: + `X_cam = X_world R + T` + Args: + R: (N, 3, 3) matrix representing the rotation. + T: (N, 3) matrix representing the translation. + Returns: + a Transform3d object which represents the composed RT transformation. + """ + # TODO: also support the case where RT is specified as one matrix + # of shape (N, 4, 4). + + if T.shape[0] != R.shape[0]: + msg = "Expected R, T to have the same batch dimension; got %r, %r" + raise ValueError(msg % (R.shape[0], T.shape[0])) + if T.dim() != 2 or T.shape[1:] != (3,): + msg = "Expected T to have shape (N, 3); got %r" + raise ValueError(msg % repr(T.shape)) + if R.dim() != 3 or R.shape[1:] != (3, 3): + msg = "Expected R to have shape (N, 3, 3); got %r" + raise ValueError(msg % repr(R.shape)) + + # Create a Transform3d object + T_ = Translate(T, device=T.device) + R_ = Rotate(R, device=R.device) + return R_.compose(T_) + +def _check_valid_rotation_matrix(R, tol: float = 1e-7) -> None: + """ + Determine if R is a valid rotation matrix by checking it satisfies the + following conditions: + + ``RR^T = I and det(R) = 1`` + + Args: + R: an (N, 3, 3) matrix + + Returns: + None + + Emits a warning if R is an invalid rotation matrix. + """ + N = R.shape[0] + eye = torch.eye(3, dtype=R.dtype, device=R.device) + eye = eye.view(1, 3, 3).expand(N, -1, -1) + orthogonal = torch.allclose(R.bmm(R.transpose(1, 2)), eye, atol=tol) + det_R = _safe_det_3x3(R) + no_distortion = torch.allclose(det_R, torch.ones_like(det_R)) + if not (orthogonal and no_distortion): + msg = "R is not a valid rotation matrix" + warnings.warn(msg) + return + +def format_tensor( + input, + dtype: torch.dtype = torch.float32, + device: Device = "cpu", +) -> torch.Tensor: + """ + Helper function for converting a scalar value to a tensor. + Args: + input: Python scalar, Python list/tuple, torch scalar, 1D torch tensor + dtype: data type for the input + device: Device (as str or torch.device) on which the tensor should be placed. + Returns: + input_vec: torch tensor with optional added batch dimension. + """ + device_ = make_device(device) + if not torch.is_tensor(input): + input = torch.tensor(input, dtype=dtype, device=device_) + + if input.dim() == 0: + input = input.view(1) + + if input.device == device_: + return input + + input = input.to(device=device) + return input + +def convert_to_tensors_and_broadcast( + *args, + dtype: torch.dtype = torch.float32, + device: Device = "cpu", +): + """ + Helper function to handle parsing an arbitrary number of inputs (*args) + which all need to have the same batch dimension. + The output is a list of tensors. + Args: + *args: an arbitrary number of inputs + Each of the values in `args` can be one of the following + - Python scalar + - Torch scalar + - Torch tensor of shape (N, K_i) or (1, K_i) where K_i are + an arbitrary number of dimensions which can vary for each + value in args. In this case each input is broadcast to a + tensor of shape (N, K_i) + dtype: data type to use when creating new tensors. + device: torch device on which the tensors should be placed. + Output: + args: A list of tensors of shape (N, K_i) + """ + # Convert all inputs to tensors with a batch dimension + args_1d = [format_tensor(c, dtype, device) for c in args] + + # Find broadcast size + sizes = [c.shape[0] for c in args_1d] + N = max(sizes) + + args_Nd = [] + for c in args_1d: + if c.shape[0] != 1 and c.shape[0] != N: + msg = "Got non-broadcastable sizes %r" % sizes + raise ValueError(msg) + + # Expand broadcast dim and keep non broadcast dims the same size + expand_sizes = (N,) + (-1,) * len(c.shape[1:]) + args_Nd.append(c.expand(*expand_sizes)) + + return args_Nd + +def _handle_coord(c, dtype: torch.dtype, device: torch.device) -> torch.Tensor: + """ + Helper function for _handle_input. + + Args: + c: Python scalar, torch scalar, or 1D torch tensor + + Returns: + c_vec: 1D torch tensor + """ + if not torch.is_tensor(c): + c = torch.tensor(c, dtype=dtype, device=device) + if c.dim() == 0: + c = c.view(1) + if c.device != device or c.dtype != dtype: + c = c.to(device=device, dtype=dtype) + return c + +def _handle_input( + x, + y, + z, + dtype: torch.dtype, + device: Optional[Device], + name: str, + allow_singleton: bool = False, +) -> torch.Tensor: + """ + Helper function to handle parsing logic for building transforms. The output + is always a tensor of shape (N, 3), but there are several types of allowed + input. + + Case I: Single Matrix + In this case x is a tensor of shape (N, 3), and y and z are None. Here just + return x. + + Case II: Vectors and Scalars + In this case each of x, y, and z can be one of the following + - Python scalar + - Torch scalar + - Torch tensor of shape (N, 1) or (1, 1) + In this case x, y and z are broadcast to tensors of shape (N, 1) + and concatenated to a tensor of shape (N, 3) + + Case III: Singleton (only if allow_singleton=True) + In this case y and z are None, and x can be one of the following: + - Python scalar + - Torch scalar + - Torch tensor of shape (N, 1) or (1, 1) + Here x will be duplicated 3 times, and we return a tensor of shape (N, 3) + + Returns: + xyz: Tensor of shape (N, 3) + """ + device_ = get_device(x, device) + # If x is actually a tensor of shape (N, 3) then just return it + if torch.is_tensor(x) and x.dim() == 2: + if x.shape[1] != 3: + msg = "Expected tensor of shape (N, 3); got %r (in %s)" + raise ValueError(msg % (x.shape, name)) + if y is not None or z is not None: + msg = "Expected y and z to be None (in %s)" % name + raise ValueError(msg) + return x.to(device=device_, dtype=dtype) + + if allow_singleton and y is None and z is None: + y = x + z = x + + # Convert all to 1D tensors + xyz = [_handle_coord(c, dtype, device_) for c in [x, y, z]] + + # Broadcast and concatenate + sizes = [c.shape[0] for c in xyz] + N = max(sizes) + for c in xyz: + if c.shape[0] != 1 and c.shape[0] != N: + msg = "Got non-broadcastable sizes %r (in %s)" % (sizes, name) + raise ValueError(msg) + xyz = [c.expand(N) for c in xyz] + xyz = torch.stack(xyz, dim=1) + return xyz diff --git a/deforum-stable-diffusion/src/rank_images.py b/deforum-stable-diffusion/src/rank_images.py new file mode 100644 index 0000000000000000000000000000000000000000..35a1994e5051d52c0df8244e74f0d644093b60c4 --- /dev/null +++ b/deforum-stable-diffusion/src/rank_images.py @@ -0,0 +1,69 @@ +import os +from argparse import ArgumentParser +from tqdm import tqdm +from PIL import Image +from torch.nn import functional as F +from torchvision import transforms +from torchvision.transforms import functional as TF +import torch +from simulacra_fit_linear_model import AestheticMeanPredictionLinearModel +from CLIP import clip + +parser = ArgumentParser() +parser.add_argument("directory") +parser.add_argument("-t", "--top-n", default=50) +args = parser.parse_args() + +device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') + +clip_model_name = 'ViT-B/16' +clip_model = clip.load(clip_model_name, jit=False, device=device)[0] +clip_model.eval().requires_grad_(False) + +normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], + std=[0.26862954, 0.26130258, 0.27577711]) + +# 512 is embed dimension for ViT-B/16 CLIP +model = AestheticMeanPredictionLinearModel(512) +model.load_state_dict( + torch.load("models/sac_public_2022_06_29_vit_b_16_linear.pth") +) +model = model.to(device) + +def get_filepaths(parentpath, filepaths): + paths = [] + for path in filepaths: + try: + new_parent = os.path.join(parentpath, path) + paths += get_filepaths(new_parent, os.listdir(new_parent)) + except NotADirectoryError: + paths.append(os.path.join(parentpath, path)) + return paths + +filepaths = get_filepaths(args.directory, os.listdir(args.directory)) +scores = [] +for path in tqdm(filepaths): + # This is obviously a flawed way to check for an image but this is just + # a demo script anyway. + if path[-4:] not in (".png", ".jpg"): + continue + img = Image.open(path).convert('RGB') + img = TF.resize(img, 224, transforms.InterpolationMode.LANCZOS) + img = TF.center_crop(img, (224,224)) + img = TF.to_tensor(img).to(device) + img = normalize(img) + clip_image_embed = F.normalize( + clip_model.encode_image(img[None, ...]).float(), + dim=-1) + score = model(clip_image_embed) + if len(scores) < args.top_n: + scores.append((score.item(),path)) + scores.sort() + else: + if scores[0][0] < score: + scores.append((score.item(),path)) + scores.sort(key=lambda x: x[0]) + scores = scores[1:] + +for score, path in scores: + print(f"{score}: {path}") diff --git a/deforum-stable-diffusion/src/simulacra_compute_embeddings.py b/deforum-stable-diffusion/src/simulacra_compute_embeddings.py new file mode 100644 index 0000000000000000000000000000000000000000..1fd3bfc1db73542a77ca1f9be438877f314fd1e6 --- /dev/null +++ b/deforum-stable-diffusion/src/simulacra_compute_embeddings.py @@ -0,0 +1,96 @@ +#!/usr/bin/env python3 + +"""Precomputes CLIP embeddings for Simulacra Aesthetic Captions.""" + +import argparse +import os +from pathlib import Path +import sqlite3 + +from PIL import Image + +import torch +from torch import multiprocessing as mp +from torch.utils import data +import torchvision.transforms as transforms +from tqdm import tqdm + +from CLIP import clip + + +class SimulacraDataset(data.Dataset): + """Simulacra dataset + Args: + images_dir: directory + transform: preprocessing and augmentation of the training images + """ + + def __init__(self, images_dir, db, transform=None): + self.images_dir = Path(images_dir) + self.transform = transform + self.conn = sqlite3.connect(db) + self.ratings = [] + for row in self.conn.execute('SELECT generations.id, images.idx, paths.path, AVG(ratings.rating) FROM images JOIN generations ON images.gid=generations.id JOIN ratings ON images.id=ratings.iid JOIN paths ON images.id=paths.iid GROUP BY images.id'): + self.ratings.append(row) + + def __len__(self): + return len(self.ratings) + + def __getitem__(self, key): + gid, idx, filename, rating = self.ratings[key] + image = Image.open(self.images_dir / filename).convert('RGB') + if self.transform: + image = self.transform(image) + return image, torch.tensor(rating) + + +def main(): + p = argparse.ArgumentParser(description=__doc__) + p.add_argument('--batch-size', '-bs', type=int, default=10, + help='the CLIP model') + p.add_argument('--clip-model', type=str, default='ViT-B/16', + help='the CLIP model') + p.add_argument('--db', type=str, required=True, + help='the database location') + p.add_argument('--device', type=str, + help='the device to use') + p.add_argument('--images-dir', type=str, required=True, + help='the dataset images directory') + p.add_argument('--num-workers', type=int, default=8, + help='the number of data loader workers') + p.add_argument('--output', type=str, required=True, + help='the output file') + p.add_argument('--start-method', type=str, default='spawn', + choices=['fork', 'forkserver', 'spawn'], + help='the multiprocessing start method') + args = p.parse_args() + + mp.set_start_method(args.start_method) + if args.device: + device = torch.device(device) + else: + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + print('Using device:', device) + + clip_model, clip_tf = clip.load(args.clip_model, device=device, jit=False) + clip_model = clip_model.eval().requires_grad_(False) + + dataset = SimulacraDataset(args.images_dir, args.db, transform=clip_tf) + loader = data.DataLoader(dataset, args.batch_size, num_workers=args.num_workers) + + embeds, ratings = [], [] + + for batch in tqdm(loader): + images_batch, ratings_batch = batch + embeds.append(clip_model.encode_image(images_batch.to(device)).cpu()) + ratings.append(ratings_batch.clone()) + + obj = {'clip_model': args.clip_model, + 'embeds': torch.cat(embeds), + 'ratings': torch.cat(ratings)} + + torch.save(obj, args.output) + + +if __name__ == '__main__': + main() diff --git a/deforum-stable-diffusion/src/simulacra_fit_linear_model.py b/deforum-stable-diffusion/src/simulacra_fit_linear_model.py new file mode 100644 index 0000000000000000000000000000000000000000..0a80e77f406ca068fd2912040585f2086e7b5436 --- /dev/null +++ b/deforum-stable-diffusion/src/simulacra_fit_linear_model.py @@ -0,0 +1,52 @@ +#!/usr/bin/env python3 + +"""Fits a linear aesthetic model to precomputed CLIP embeddings.""" + +import argparse + +import numpy as np +from sklearn.linear_model import Ridge +from sklearn.model_selection import train_test_split +import torch +from torch import nn +from torch.nn import functional as F + + +class AestheticMeanPredictionLinearModel(nn.Module): + def __init__(self, feats_in): + super().__init__() + self.linear = nn.Linear(feats_in, 1) + + def forward(self, input): + x = F.normalize(input, dim=-1) * input.shape[-1] ** 0.5 + return self.linear(x) + + +def main(): + p = argparse.ArgumentParser(description=__doc__) + p.add_argument('input', type=str, help='the input feature vectors') + p.add_argument('output', type=str, help='the output model') + p.add_argument('--val-size', type=float, default=0.1, help='the validation set size') + p.add_argument('--seed', type=int, default=0, help='the random seed') + args = p.parse_args() + + train_set = torch.load(args.input, map_location='cpu') + X = F.normalize(train_set['embeds'].float(), dim=-1).numpy() + X *= X.shape[-1] ** 0.5 + y = train_set['ratings'].numpy() + X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=args.val_size, random_state=args.seed) + regression = Ridge() + regression.fit(X_train, y_train) + score_train = regression.score(X_train, y_train) + score_val = regression.score(X_val, y_val) + print(f'Score on train: {score_train:g}') + print(f'Score on val: {score_val:g}') + model = AestheticMeanPredictionLinearModel(X_train.shape[1]) + with torch.no_grad(): + model.linear.weight.copy_(torch.tensor(regression.coef_)) + model.linear.bias.copy_(torch.tensor(regression.intercept_)) + torch.save(model.state_dict(), args.output) + + +if __name__ == '__main__': + main() diff --git a/deforum-stable-diffusion/src/taming/data/ade20k.py b/deforum-stable-diffusion/src/taming/data/ade20k.py new file mode 100644 index 0000000000000000000000000000000000000000..366dae97207dbb8356598d636e14ad084d45bc76 --- /dev/null +++ b/deforum-stable-diffusion/src/taming/data/ade20k.py @@ -0,0 +1,124 @@ +import os +import numpy as np +import cv2 +import albumentations +from PIL import Image +from torch.utils.data import Dataset + +from taming.data.sflckr import SegmentationBase # for examples included in repo + + +class Examples(SegmentationBase): + def __init__(self, size=256, random_crop=False, interpolation="bicubic"): + super().__init__(data_csv="data/ade20k_examples.txt", + data_root="data/ade20k_images", + segmentation_root="data/ade20k_segmentations", + size=size, random_crop=random_crop, + interpolation=interpolation, + n_labels=151, shift_segmentation=False) + + +# With semantic map and scene label +class ADE20kBase(Dataset): + def __init__(self, config=None, size=None, random_crop=False, interpolation="bicubic", crop_size=None): + self.split = self.get_split() + self.n_labels = 151 # unknown + 150 + self.data_csv = {"train": "data/ade20k_train.txt", + "validation": "data/ade20k_test.txt"}[self.split] + self.data_root = "data/ade20k_root" + with open(os.path.join(self.data_root, "sceneCategories.txt"), "r") as f: + self.scene_categories = f.read().splitlines() + self.scene_categories = dict(line.split() for line in self.scene_categories) + with open(self.data_csv, "r") as f: + self.image_paths = f.read().splitlines() + self._length = len(self.image_paths) + self.labels = { + "relative_file_path_": [l for l in self.image_paths], + "file_path_": [os.path.join(self.data_root, "images", l) + for l in self.image_paths], + "relative_segmentation_path_": [l.replace(".jpg", ".png") + for l in self.image_paths], + "segmentation_path_": [os.path.join(self.data_root, "annotations", + l.replace(".jpg", ".png")) + for l in self.image_paths], + "scene_category": [self.scene_categories[l.split("/")[1].replace(".jpg", "")] + for l in self.image_paths], + } + + size = None if size is not None and size<=0 else size + self.size = size + if crop_size is None: + self.crop_size = size if size is not None else None + else: + self.crop_size = crop_size + if self.size is not None: + self.interpolation = interpolation + self.interpolation = { + "nearest": cv2.INTER_NEAREST, + "bilinear": cv2.INTER_LINEAR, + "bicubic": cv2.INTER_CUBIC, + "area": cv2.INTER_AREA, + "lanczos": cv2.INTER_LANCZOS4}[self.interpolation] + self.image_rescaler = albumentations.SmallestMaxSize(max_size=self.size, + interpolation=self.interpolation) + self.segmentation_rescaler = albumentations.SmallestMaxSize(max_size=self.size, + interpolation=cv2.INTER_NEAREST) + + if crop_size is not None: + self.center_crop = not random_crop + if self.center_crop: + self.cropper = albumentations.CenterCrop(height=self.crop_size, width=self.crop_size) + else: + self.cropper = albumentations.RandomCrop(height=self.crop_size, width=self.crop_size) + self.preprocessor = self.cropper + + def __len__(self): + return self._length + + def __getitem__(self, i): + example = dict((k, self.labels[k][i]) for k in self.labels) + image = Image.open(example["file_path_"]) + if not image.mode == "RGB": + image = image.convert("RGB") + image = np.array(image).astype(np.uint8) + if self.size is not None: + image = self.image_rescaler(image=image)["image"] + segmentation = Image.open(example["segmentation_path_"]) + segmentation = np.array(segmentation).astype(np.uint8) + if self.size is not None: + segmentation = self.segmentation_rescaler(image=segmentation)["image"] + if self.size is not None: + processed = self.preprocessor(image=image, mask=segmentation) + else: + processed = {"image": image, "mask": segmentation} + example["image"] = (processed["image"]/127.5 - 1.0).astype(np.float32) + segmentation = processed["mask"] + onehot = np.eye(self.n_labels)[segmentation] + example["segmentation"] = onehot + return example + + +class ADE20kTrain(ADE20kBase): + # default to random_crop=True + def __init__(self, config=None, size=None, random_crop=True, interpolation="bicubic", crop_size=None): + super().__init__(config=config, size=size, random_crop=random_crop, + interpolation=interpolation, crop_size=crop_size) + + def get_split(self): + return "train" + + +class ADE20kValidation(ADE20kBase): + def get_split(self): + return "validation" + + +if __name__ == "__main__": + dset = ADE20kValidation() + ex = dset[0] + for k in ["image", "scene_category", "segmentation"]: + print(type(ex[k])) + try: + print(ex[k].shape) + except: + print(ex[k]) diff --git a/deforum-stable-diffusion/src/taming/data/annotated_objects_coco.py b/deforum-stable-diffusion/src/taming/data/annotated_objects_coco.py new file mode 100644 index 0000000000000000000000000000000000000000..af000ecd943d7b8a85d7eb70195c9ecd10ab5edc --- /dev/null +++ b/deforum-stable-diffusion/src/taming/data/annotated_objects_coco.py @@ -0,0 +1,139 @@ +import json +from itertools import chain +from pathlib import Path +from typing import Iterable, Dict, List, Callable, Any +from collections import defaultdict + +from tqdm import tqdm + +from taming.data.annotated_objects_dataset import AnnotatedObjectsDataset +from taming.data.helper_types import Annotation, ImageDescription, Category + +COCO_PATH_STRUCTURE = { + 'train': { + 'top_level': '', + 'instances_annotations': 'annotations/instances_train2017.json', + 'stuff_annotations': 'annotations/stuff_train2017.json', + 'files': 'train2017' + }, + 'validation': { + 'top_level': '', + 'instances_annotations': 'annotations/instances_val2017.json', + 'stuff_annotations': 'annotations/stuff_val2017.json', + 'files': 'val2017' + } +} + + +def load_image_descriptions(description_json: List[Dict]) -> Dict[str, ImageDescription]: + return { + str(img['id']): ImageDescription( + id=img['id'], + license=img.get('license'), + file_name=img['file_name'], + coco_url=img['coco_url'], + original_size=(img['width'], img['height']), + date_captured=img.get('date_captured'), + flickr_url=img.get('flickr_url') + ) + for img in description_json + } + + +def load_categories(category_json: Iterable) -> Dict[str, Category]: + return {str(cat['id']): Category(id=str(cat['id']), super_category=cat['supercategory'], name=cat['name']) + for cat in category_json if cat['name'] != 'other'} + + +def load_annotations(annotations_json: List[Dict], image_descriptions: Dict[str, ImageDescription], + category_no_for_id: Callable[[str], int], split: str) -> Dict[str, List[Annotation]]: + annotations = defaultdict(list) + total = sum(len(a) for a in annotations_json) + for ann in tqdm(chain(*annotations_json), f'Loading {split} annotations', total=total): + image_id = str(ann['image_id']) + if image_id not in image_descriptions: + raise ValueError(f'image_id [{image_id}] has no image description.') + category_id = ann['category_id'] + try: + category_no = category_no_for_id(str(category_id)) + except KeyError: + continue + + width, height = image_descriptions[image_id].original_size + bbox = (ann['bbox'][0] / width, ann['bbox'][1] / height, ann['bbox'][2] / width, ann['bbox'][3] / height) + + annotations[image_id].append( + Annotation( + id=ann['id'], + area=bbox[2]*bbox[3], # use bbox area + is_group_of=ann['iscrowd'], + image_id=ann['image_id'], + bbox=bbox, + category_id=str(category_id), + category_no=category_no + ) + ) + return dict(annotations) + + +class AnnotatedObjectsCoco(AnnotatedObjectsDataset): + def __init__(self, use_things: bool = True, use_stuff: bool = True, **kwargs): + """ + @param data_path: is the path to the following folder structure: + coco/ + ├── annotations + │ ├── instances_train2017.json + │ ├── instances_val2017.json + │ ├── stuff_train2017.json + │ └── stuff_val2017.json + ├── train2017 + │ ├── 000000000009.jpg + │ ├── 000000000025.jpg + │ └── ... + ├── val2017 + │ ├── 000000000139.jpg + │ ├── 000000000285.jpg + │ └── ... + @param: split: one of 'train' or 'validation' + @param: desired image size (give square images) + """ + super().__init__(**kwargs) + self.use_things = use_things + self.use_stuff = use_stuff + + with open(self.paths['instances_annotations']) as f: + inst_data_json = json.load(f) + with open(self.paths['stuff_annotations']) as f: + stuff_data_json = json.load(f) + + category_jsons = [] + annotation_jsons = [] + if self.use_things: + category_jsons.append(inst_data_json['categories']) + annotation_jsons.append(inst_data_json['annotations']) + if self.use_stuff: + category_jsons.append(stuff_data_json['categories']) + annotation_jsons.append(stuff_data_json['annotations']) + + self.categories = load_categories(chain(*category_jsons)) + self.filter_categories() + self.setup_category_id_and_number() + + self.image_descriptions = load_image_descriptions(inst_data_json['images']) + annotations = load_annotations(annotation_jsons, self.image_descriptions, self.get_category_number, self.split) + self.annotations = self.filter_object_number(annotations, self.min_object_area, + self.min_objects_per_image, self.max_objects_per_image) + self.image_ids = list(self.annotations.keys()) + self.clean_up_annotations_and_image_descriptions() + + def get_path_structure(self) -> Dict[str, str]: + if self.split not in COCO_PATH_STRUCTURE: + raise ValueError(f'Split [{self.split} does not exist for COCO data.]') + return COCO_PATH_STRUCTURE[self.split] + + def get_image_path(self, image_id: str) -> Path: + return self.paths['files'].joinpath(self.image_descriptions[str(image_id)].file_name) + + def get_image_description(self, image_id: str) -> Dict[str, Any]: + # noinspection PyProtectedMember + return self.image_descriptions[image_id]._asdict() diff --git a/deforum-stable-diffusion/src/taming/data/base.py b/deforum-stable-diffusion/src/taming/data/base.py new file mode 100644 index 0000000000000000000000000000000000000000..e21667df4ce4baa6bb6aad9f8679bd756e2ffdb7 --- /dev/null +++ b/deforum-stable-diffusion/src/taming/data/base.py @@ -0,0 +1,70 @@ +import bisect +import numpy as np +import albumentations +from PIL import Image +from torch.utils.data import Dataset, ConcatDataset + + +class ConcatDatasetWithIndex(ConcatDataset): + """Modified from original pytorch code to return dataset idx""" + def __getitem__(self, idx): + if idx < 0: + if -idx > len(self): + raise ValueError("absolute value of index should not exceed dataset length") + idx = len(self) + idx + dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx) + if dataset_idx == 0: + sample_idx = idx + else: + sample_idx = idx - self.cumulative_sizes[dataset_idx - 1] + return self.datasets[dataset_idx][sample_idx], dataset_idx + + +class ImagePaths(Dataset): + def __init__(self, paths, size=None, random_crop=False, labels=None): + self.size = size + self.random_crop = random_crop + + self.labels = dict() if labels is None else labels + self.labels["file_path_"] = paths + self._length = len(paths) + + if self.size is not None and self.size > 0: + self.rescaler = albumentations.SmallestMaxSize(max_size = self.size) + if not self.random_crop: + self.cropper = albumentations.CenterCrop(height=self.size,width=self.size) + else: + self.cropper = albumentations.RandomCrop(height=self.size,width=self.size) + self.preprocessor = albumentations.Compose([self.rescaler, self.cropper]) + else: + self.preprocessor = lambda **kwargs: kwargs + + def __len__(self): + return self._length + + def preprocess_image(self, image_path): + image = Image.open(image_path) + if not image.mode == "RGB": + image = image.convert("RGB") + image = np.array(image).astype(np.uint8) + image = self.preprocessor(image=image)["image"] + image = (image/127.5 - 1.0).astype(np.float32) + return image + + def __getitem__(self, i): + example = dict() + example["image"] = self.preprocess_image(self.labels["file_path_"][i]) + for k in self.labels: + example[k] = self.labels[k][i] + return example + + +class NumpyPaths(ImagePaths): + def preprocess_image(self, image_path): + image = np.load(image_path).squeeze(0) # 3 x 1024 x 1024 + image = np.transpose(image, (1,2,0)) + image = Image.fromarray(image, mode="RGB") + image = np.array(image).astype(np.uint8) + image = self.preprocessor(image=image)["image"] + image = (image/127.5 - 1.0).astype(np.float32) + return image diff --git a/deforum-stable-diffusion/src/taming/data/coco.py b/deforum-stable-diffusion/src/taming/data/coco.py new file mode 100644 index 0000000000000000000000000000000000000000..2b2f7838448cb63dcf96daffe9470d58566d975a --- /dev/null +++ b/deforum-stable-diffusion/src/taming/data/coco.py @@ -0,0 +1,176 @@ +import os +import json +import albumentations +import numpy as np +from PIL import Image +from tqdm import tqdm +from torch.utils.data import Dataset + +from taming.data.sflckr import SegmentationBase # for examples included in repo + + +class Examples(SegmentationBase): + def __init__(self, size=256, random_crop=False, interpolation="bicubic"): + super().__init__(data_csv="data/coco_examples.txt", + data_root="data/coco_images", + segmentation_root="data/coco_segmentations", + size=size, random_crop=random_crop, + interpolation=interpolation, + n_labels=183, shift_segmentation=True) + + +class CocoBase(Dataset): + """needed for (image, caption, segmentation) pairs""" + def __init__(self, size=None, dataroot="", datajson="", onehot_segmentation=False, use_stuffthing=False, + crop_size=None, force_no_crop=False, given_files=None): + self.split = self.get_split() + self.size = size + if crop_size is None: + self.crop_size = size + else: + self.crop_size = crop_size + + self.onehot = onehot_segmentation # return segmentation as rgb or one hot + self.stuffthing = use_stuffthing # include thing in segmentation + if self.onehot and not self.stuffthing: + raise NotImplemented("One hot mode is only supported for the " + "stuffthings version because labels are stored " + "a bit different.") + + data_json = datajson + with open(data_json) as json_file: + self.json_data = json.load(json_file) + self.img_id_to_captions = dict() + self.img_id_to_filepath = dict() + self.img_id_to_segmentation_filepath = dict() + + assert data_json.split("/")[-1] in ["captions_train2017.json", + "captions_val2017.json"] + if self.stuffthing: + self.segmentation_prefix = ( + "data/cocostuffthings/val2017" if + data_json.endswith("captions_val2017.json") else + "data/cocostuffthings/train2017") + else: + self.segmentation_prefix = ( + "data/coco/annotations/stuff_val2017_pixelmaps" if + data_json.endswith("captions_val2017.json") else + "data/coco/annotations/stuff_train2017_pixelmaps") + + imagedirs = self.json_data["images"] + self.labels = {"image_ids": list()} + for imgdir in tqdm(imagedirs, desc="ImgToPath"): + self.img_id_to_filepath[imgdir["id"]] = os.path.join(dataroot, imgdir["file_name"]) + self.img_id_to_captions[imgdir["id"]] = list() + pngfilename = imgdir["file_name"].replace("jpg", "png") + self.img_id_to_segmentation_filepath[imgdir["id"]] = os.path.join( + self.segmentation_prefix, pngfilename) + if given_files is not None: + if pngfilename in given_files: + self.labels["image_ids"].append(imgdir["id"]) + else: + self.labels["image_ids"].append(imgdir["id"]) + + capdirs = self.json_data["annotations"] + for capdir in tqdm(capdirs, desc="ImgToCaptions"): + # there are in average 5 captions per image + self.img_id_to_captions[capdir["image_id"]].append(np.array([capdir["caption"]])) + + self.rescaler = albumentations.SmallestMaxSize(max_size=self.size) + if self.split=="validation": + self.cropper = albumentations.CenterCrop(height=self.crop_size, width=self.crop_size) + else: + self.cropper = albumentations.RandomCrop(height=self.crop_size, width=self.crop_size) + self.preprocessor = albumentations.Compose( + [self.rescaler, self.cropper], + additional_targets={"segmentation": "image"}) + if force_no_crop: + self.rescaler = albumentations.Resize(height=self.size, width=self.size) + self.preprocessor = albumentations.Compose( + [self.rescaler], + additional_targets={"segmentation": "image"}) + + def __len__(self): + return len(self.labels["image_ids"]) + + def preprocess_image(self, image_path, segmentation_path): + image = Image.open(image_path) + if not image.mode == "RGB": + image = image.convert("RGB") + image = np.array(image).astype(np.uint8) + + segmentation = Image.open(segmentation_path) + if not self.onehot and not segmentation.mode == "RGB": + segmentation = segmentation.convert("RGB") + segmentation = np.array(segmentation).astype(np.uint8) + if self.onehot: + assert self.stuffthing + # stored in caffe format: unlabeled==255. stuff and thing from + # 0-181. to be compatible with the labels in + # https://github.com/nightrome/cocostuff/blob/master/labels.txt + # we shift stuffthing one to the right and put unlabeled in zero + # as long as segmentation is uint8 shifting to right handles the + # latter too + assert segmentation.dtype == np.uint8 + segmentation = segmentation + 1 + + processed = self.preprocessor(image=image, segmentation=segmentation) + image, segmentation = processed["image"], processed["segmentation"] + image = (image / 127.5 - 1.0).astype(np.float32) + + if self.onehot: + assert segmentation.dtype == np.uint8 + # make it one hot + n_labels = 183 + flatseg = np.ravel(segmentation) + onehot = np.zeros((flatseg.size, n_labels), dtype=np.bool) + onehot[np.arange(flatseg.size), flatseg] = True + onehot = onehot.reshape(segmentation.shape + (n_labels,)).astype(int) + segmentation = onehot + else: + segmentation = (segmentation / 127.5 - 1.0).astype(np.float32) + return image, segmentation + + def __getitem__(self, i): + img_path = self.img_id_to_filepath[self.labels["image_ids"][i]] + seg_path = self.img_id_to_segmentation_filepath[self.labels["image_ids"][i]] + image, segmentation = self.preprocess_image(img_path, seg_path) + captions = self.img_id_to_captions[self.labels["image_ids"][i]] + # randomly draw one of all available captions per image + caption = captions[np.random.randint(0, len(captions))] + example = {"image": image, + "caption": [str(caption[0])], + "segmentation": segmentation, + "img_path": img_path, + "seg_path": seg_path, + "filename_": img_path.split(os.sep)[-1] + } + return example + + +class CocoImagesAndCaptionsTrain(CocoBase): + """returns a pair of (image, caption)""" + def __init__(self, size, onehot_segmentation=False, use_stuffthing=False, crop_size=None, force_no_crop=False): + super().__init__(size=size, + dataroot="data/coco/train2017", + datajson="data/coco/annotations/captions_train2017.json", + onehot_segmentation=onehot_segmentation, + use_stuffthing=use_stuffthing, crop_size=crop_size, force_no_crop=force_no_crop) + + def get_split(self): + return "train" + + +class CocoImagesAndCaptionsValidation(CocoBase): + """returns a pair of (image, caption)""" + def __init__(self, size, onehot_segmentation=False, use_stuffthing=False, crop_size=None, force_no_crop=False, + given_files=None): + super().__init__(size=size, + dataroot="data/coco/val2017", + datajson="data/coco/annotations/captions_val2017.json", + onehot_segmentation=onehot_segmentation, + use_stuffthing=use_stuffthing, crop_size=crop_size, force_no_crop=force_no_crop, + given_files=given_files) + + def get_split(self): + return "validation" diff --git a/deforum-stable-diffusion/src/taming/data/conditional_builder/objects_bbox.py b/deforum-stable-diffusion/src/taming/data/conditional_builder/objects_bbox.py new file mode 100644 index 0000000000000000000000000000000000000000..15881e76b7ab2a914df8f2dfe08ae4f0c6c511b5 --- /dev/null +++ b/deforum-stable-diffusion/src/taming/data/conditional_builder/objects_bbox.py @@ -0,0 +1,60 @@ +from itertools import cycle +from typing import List, Tuple, Callable, Optional + +from PIL import Image as pil_image, ImageDraw as pil_img_draw, ImageFont +from more_itertools.recipes import grouper +from taming.data.image_transforms import convert_pil_to_tensor +from torch import LongTensor, Tensor + +from taming.data.helper_types import BoundingBox, Annotation +from taming.data.conditional_builder.objects_center_points import ObjectsCenterPointsConditionalBuilder +from taming.data.conditional_builder.utils import COLOR_PALETTE, WHITE, GRAY_75, BLACK, additional_parameters_string, \ + pad_list, get_plot_font_size, absolute_bbox + + +class ObjectsBoundingBoxConditionalBuilder(ObjectsCenterPointsConditionalBuilder): + @property + def object_descriptor_length(self) -> int: + return 3 + + def _make_object_descriptors(self, annotations: List[Annotation]) -> List[Tuple[int, ...]]: + object_triples = [ + (self.object_representation(ann), *self.token_pair_from_bbox(ann.bbox)) + for ann in annotations + ] + empty_triple = (self.none, self.none, self.none) + object_triples = pad_list(object_triples, empty_triple, self.no_max_objects) + return object_triples + + def inverse_build(self, conditional: LongTensor) -> Tuple[List[Tuple[int, BoundingBox]], Optional[BoundingBox]]: + conditional_list = conditional.tolist() + crop_coordinates = None + if self.encode_crop: + crop_coordinates = self.bbox_from_token_pair(conditional_list[-2], conditional_list[-1]) + conditional_list = conditional_list[:-2] + object_triples = grouper(conditional_list, 3) + assert conditional.shape[0] == self.embedding_dim + return [ + (object_triple[0], self.bbox_from_token_pair(object_triple[1], object_triple[2])) + for object_triple in object_triples if object_triple[0] != self.none + ], crop_coordinates + + def plot(self, conditional: LongTensor, label_for_category_no: Callable[[int], str], figure_size: Tuple[int, int], + line_width: int = 3, font_size: Optional[int] = None) -> Tensor: + plot = pil_image.new('RGB', figure_size, WHITE) + draw = pil_img_draw.Draw(plot) + font = ImageFont.truetype( + "/usr/share/fonts/truetype/lato/Lato-Regular.ttf", + size=get_plot_font_size(font_size, figure_size) + ) + width, height = plot.size + description, crop_coordinates = self.inverse_build(conditional) + for (representation, bbox), color in zip(description, cycle(COLOR_PALETTE)): + annotation = self.representation_to_annotation(representation) + class_label = label_for_category_no(annotation.category_no) + ' ' + additional_parameters_string(annotation) + bbox = absolute_bbox(bbox, width, height) + draw.rectangle(bbox, outline=color, width=line_width) + draw.text((bbox[0] + line_width, bbox[1] + line_width), class_label, anchor='la', fill=BLACK, font=font) + if crop_coordinates is not None: + draw.rectangle(absolute_bbox(crop_coordinates, width, height), outline=GRAY_75, width=line_width) + return convert_pil_to_tensor(plot) / 127.5 - 1. diff --git a/deforum-stable-diffusion/src/taming/data/conditional_builder/objects_center_points.py b/deforum-stable-diffusion/src/taming/data/conditional_builder/objects_center_points.py new file mode 100644 index 0000000000000000000000000000000000000000..9a480329cc47fb38a7b8729d424e092b77d40749 --- /dev/null +++ b/deforum-stable-diffusion/src/taming/data/conditional_builder/objects_center_points.py @@ -0,0 +1,168 @@ +import math +import random +import warnings +from itertools import cycle +from typing import List, Optional, Tuple, Callable + +from PIL import Image as pil_image, ImageDraw as pil_img_draw, ImageFont +from more_itertools.recipes import grouper +from taming.data.conditional_builder.utils import COLOR_PALETTE, WHITE, GRAY_75, BLACK, FULL_CROP, filter_annotations, \ + additional_parameters_string, horizontally_flip_bbox, pad_list, get_circle_size, get_plot_font_size, \ + absolute_bbox, rescale_annotations +from taming.data.helper_types import BoundingBox, Annotation +from taming.data.image_transforms import convert_pil_to_tensor +from torch import LongTensor, Tensor + + +class ObjectsCenterPointsConditionalBuilder: + def __init__(self, no_object_classes: int, no_max_objects: int, no_tokens: int, encode_crop: bool, + use_group_parameter: bool, use_additional_parameters: bool): + self.no_object_classes = no_object_classes + self.no_max_objects = no_max_objects + self.no_tokens = no_tokens + self.encode_crop = encode_crop + self.no_sections = int(math.sqrt(self.no_tokens)) + self.use_group_parameter = use_group_parameter + self.use_additional_parameters = use_additional_parameters + + @property + def none(self) -> int: + return self.no_tokens - 1 + + @property + def object_descriptor_length(self) -> int: + return 2 + + @property + def embedding_dim(self) -> int: + extra_length = 2 if self.encode_crop else 0 + return self.no_max_objects * self.object_descriptor_length + extra_length + + def tokenize_coordinates(self, x: float, y: float) -> int: + """ + Express 2d coordinates with one number. + Example: assume self.no_tokens = 16, then no_sections = 4: + 0 0 0 0 + 0 0 # 0 + 0 0 0 0 + 0 0 0 x + Then the # position corresponds to token 6, the x position to token 15. + @param x: float in [0, 1] + @param y: float in [0, 1] + @return: discrete tokenized coordinate + """ + x_discrete = int(round(x * (self.no_sections - 1))) + y_discrete = int(round(y * (self.no_sections - 1))) + return y_discrete * self.no_sections + x_discrete + + def coordinates_from_token(self, token: int) -> (float, float): + x = token % self.no_sections + y = token // self.no_sections + return x / (self.no_sections - 1), y / (self.no_sections - 1) + + def bbox_from_token_pair(self, token1: int, token2: int) -> BoundingBox: + x0, y0 = self.coordinates_from_token(token1) + x1, y1 = self.coordinates_from_token(token2) + return x0, y0, x1 - x0, y1 - y0 + + def token_pair_from_bbox(self, bbox: BoundingBox) -> Tuple[int, int]: + return self.tokenize_coordinates(bbox[0], bbox[1]), \ + self.tokenize_coordinates(bbox[0] + bbox[2], bbox[1] + bbox[3]) + + def inverse_build(self, conditional: LongTensor) \ + -> Tuple[List[Tuple[int, Tuple[float, float]]], Optional[BoundingBox]]: + conditional_list = conditional.tolist() + crop_coordinates = None + if self.encode_crop: + crop_coordinates = self.bbox_from_token_pair(conditional_list[-2], conditional_list[-1]) + conditional_list = conditional_list[:-2] + table_of_content = grouper(conditional_list, self.object_descriptor_length) + assert conditional.shape[0] == self.embedding_dim + return [ + (object_tuple[0], self.coordinates_from_token(object_tuple[1])) + for object_tuple in table_of_content if object_tuple[0] != self.none + ], crop_coordinates + + def plot(self, conditional: LongTensor, label_for_category_no: Callable[[int], str], figure_size: Tuple[int, int], + line_width: int = 3, font_size: Optional[int] = None) -> Tensor: + plot = pil_image.new('RGB', figure_size, WHITE) + draw = pil_img_draw.Draw(plot) + circle_size = get_circle_size(figure_size) + font = ImageFont.truetype('/usr/share/fonts/truetype/lato/Lato-Regular.ttf', + size=get_plot_font_size(font_size, figure_size)) + width, height = plot.size + description, crop_coordinates = self.inverse_build(conditional) + for (representation, (x, y)), color in zip(description, cycle(COLOR_PALETTE)): + x_abs, y_abs = x * width, y * height + ann = self.representation_to_annotation(representation) + label = label_for_category_no(ann.category_no) + ' ' + additional_parameters_string(ann) + ellipse_bbox = [x_abs - circle_size, y_abs - circle_size, x_abs + circle_size, y_abs + circle_size] + draw.ellipse(ellipse_bbox, fill=color, width=0) + draw.text((x_abs, y_abs), label, anchor='md', fill=BLACK, font=font) + if crop_coordinates is not None: + draw.rectangle(absolute_bbox(crop_coordinates, width, height), outline=GRAY_75, width=line_width) + return convert_pil_to_tensor(plot) / 127.5 - 1. + + def object_representation(self, annotation: Annotation) -> int: + modifier = 0 + if self.use_group_parameter: + modifier |= 1 * (annotation.is_group_of is True) + if self.use_additional_parameters: + modifier |= 2 * (annotation.is_occluded is True) + modifier |= 4 * (annotation.is_depiction is True) + modifier |= 8 * (annotation.is_inside is True) + return annotation.category_no + self.no_object_classes * modifier + + def representation_to_annotation(self, representation: int) -> Annotation: + category_no = representation % self.no_object_classes + modifier = representation // self.no_object_classes + # noinspection PyTypeChecker + return Annotation( + area=None, image_id=None, bbox=None, category_id=None, id=None, source=None, confidence=None, + category_no=category_no, + is_group_of=bool((modifier & 1) * self.use_group_parameter), + is_occluded=bool((modifier & 2) * self.use_additional_parameters), + is_depiction=bool((modifier & 4) * self.use_additional_parameters), + is_inside=bool((modifier & 8) * self.use_additional_parameters) + ) + + def _crop_encoder(self, crop_coordinates: BoundingBox) -> List[int]: + return list(self.token_pair_from_bbox(crop_coordinates)) + + def _make_object_descriptors(self, annotations: List[Annotation]) -> List[Tuple[int, ...]]: + object_tuples = [ + (self.object_representation(a), + self.tokenize_coordinates(a.bbox[0] + a.bbox[2] / 2, a.bbox[1] + a.bbox[3] / 2)) + for a in annotations + ] + empty_tuple = (self.none, self.none) + object_tuples = pad_list(object_tuples, empty_tuple, self.no_max_objects) + return object_tuples + + def build(self, annotations: List, crop_coordinates: Optional[BoundingBox] = None, horizontal_flip: bool = False) \ + -> LongTensor: + if len(annotations) == 0: + warnings.warn('Did not receive any annotations.') + if len(annotations) > self.no_max_objects: + warnings.warn('Received more annotations than allowed.') + annotations = annotations[:self.no_max_objects] + + if not crop_coordinates: + crop_coordinates = FULL_CROP + + random.shuffle(annotations) + annotations = filter_annotations(annotations, crop_coordinates) + if self.encode_crop: + annotations = rescale_annotations(annotations, FULL_CROP, horizontal_flip) + if horizontal_flip: + crop_coordinates = horizontally_flip_bbox(crop_coordinates) + extra = self._crop_encoder(crop_coordinates) + else: + annotations = rescale_annotations(annotations, crop_coordinates, horizontal_flip) + extra = [] + + object_tuples = self._make_object_descriptors(annotations) + flattened = [token for tuple_ in object_tuples for token in tuple_] + extra + assert len(flattened) == self.embedding_dim + assert all(0 <= value < self.no_tokens for value in flattened) + return LongTensor(flattened) diff --git a/deforum-stable-diffusion/src/taming/data/conditional_builder/utils.py b/deforum-stable-diffusion/src/taming/data/conditional_builder/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d0ee175f2e05a80dbc71c22acbecb22dddadbb42 --- /dev/null +++ b/deforum-stable-diffusion/src/taming/data/conditional_builder/utils.py @@ -0,0 +1,105 @@ +import importlib +from typing import List, Any, Tuple, Optional + +from taming.data.helper_types import BoundingBox, Annotation + +# source: seaborn, color palette tab10 +COLOR_PALETTE = [(30, 118, 179), (255, 126, 13), (43, 159, 43), (213, 38, 39), (147, 102, 188), + (139, 85, 74), (226, 118, 193), (126, 126, 126), (187, 188, 33), (22, 189, 206)] +BLACK = (0, 0, 0) +GRAY_75 = (63, 63, 63) +GRAY_50 = (127, 127, 127) +GRAY_25 = (191, 191, 191) +WHITE = (255, 255, 255) +FULL_CROP = (0., 0., 1., 1.) + + +def intersection_area(rectangle1: BoundingBox, rectangle2: BoundingBox) -> float: + """ + Give intersection area of two rectangles. + @param rectangle1: (x0, y0, w, h) of first rectangle + @param rectangle2: (x0, y0, w, h) of second rectangle + """ + rectangle1 = rectangle1[0], rectangle1[1], rectangle1[0] + rectangle1[2], rectangle1[1] + rectangle1[3] + rectangle2 = rectangle2[0], rectangle2[1], rectangle2[0] + rectangle2[2], rectangle2[1] + rectangle2[3] + x_overlap = max(0., min(rectangle1[2], rectangle2[2]) - max(rectangle1[0], rectangle2[0])) + y_overlap = max(0., min(rectangle1[3], rectangle2[3]) - max(rectangle1[1], rectangle2[1])) + return x_overlap * y_overlap + + +def horizontally_flip_bbox(bbox: BoundingBox) -> BoundingBox: + return 1 - (bbox[0] + bbox[2]), bbox[1], bbox[2], bbox[3] + + +def absolute_bbox(relative_bbox: BoundingBox, width: int, height: int) -> Tuple[int, int, int, int]: + bbox = relative_bbox + bbox = bbox[0] * width, bbox[1] * height, (bbox[0] + bbox[2]) * width, (bbox[1] + bbox[3]) * height + return int(bbox[0]), int(bbox[1]), int(bbox[2]), int(bbox[3]) + + +def pad_list(list_: List, pad_element: Any, pad_to_length: int) -> List: + return list_ + [pad_element for _ in range(pad_to_length - len(list_))] + + +def rescale_annotations(annotations: List[Annotation], crop_coordinates: BoundingBox, flip: bool) -> \ + List[Annotation]: + def clamp(x: float): + return max(min(x, 1.), 0.) + + def rescale_bbox(bbox: BoundingBox) -> BoundingBox: + x0 = clamp((bbox[0] - crop_coordinates[0]) / crop_coordinates[2]) + y0 = clamp((bbox[1] - crop_coordinates[1]) / crop_coordinates[3]) + w = min(bbox[2] / crop_coordinates[2], 1 - x0) + h = min(bbox[3] / crop_coordinates[3], 1 - y0) + if flip: + x0 = 1 - (x0 + w) + return x0, y0, w, h + + return [a._replace(bbox=rescale_bbox(a.bbox)) for a in annotations] + + +def filter_annotations(annotations: List[Annotation], crop_coordinates: BoundingBox) -> List: + return [a for a in annotations if intersection_area(a.bbox, crop_coordinates) > 0.0] + + +def additional_parameters_string(annotation: Annotation, short: bool = True) -> str: + sl = slice(1) if short else slice(None) + string = '' + if not (annotation.is_group_of or annotation.is_occluded or annotation.is_depiction or annotation.is_inside): + return string + if annotation.is_group_of: + string += 'group'[sl] + ',' + if annotation.is_occluded: + string += 'occluded'[sl] + ',' + if annotation.is_depiction: + string += 'depiction'[sl] + ',' + if annotation.is_inside: + string += 'inside'[sl] + return '(' + string.strip(",") + ')' + + +def get_plot_font_size(font_size: Optional[int], figure_size: Tuple[int, int]) -> int: + if font_size is None: + font_size = 10 + if max(figure_size) >= 256: + font_size = 12 + if max(figure_size) >= 512: + font_size = 15 + return font_size + + +def get_circle_size(figure_size: Tuple[int, int]) -> int: + circle_size = 2 + if max(figure_size) >= 256: + circle_size = 3 + if max(figure_size) >= 512: + circle_size = 4 + return circle_size + + +def load_object_from_string(object_string: str) -> Any: + """ + Source: https://stackoverflow.com/a/10773699 + """ + module_name, class_name = object_string.rsplit(".", 1) + return getattr(importlib.import_module(module_name), class_name) diff --git a/deforum-stable-diffusion/src/taming/data/helper_types.py b/deforum-stable-diffusion/src/taming/data/helper_types.py new file mode 100644 index 0000000000000000000000000000000000000000..fb51e301da08602cfead5961c4f7e1d89f6aba79 --- /dev/null +++ b/deforum-stable-diffusion/src/taming/data/helper_types.py @@ -0,0 +1,49 @@ +from typing import Dict, Tuple, Optional, NamedTuple, Union +from PIL.Image import Image as pil_image +from torch import Tensor + +try: + from typing import Literal +except ImportError: + from typing_extensions import Literal + +Image = Union[Tensor, pil_image] +BoundingBox = Tuple[float, float, float, float] # x0, y0, w, h +CropMethodType = Literal['none', 'random', 'center', 'random-2d'] +SplitType = Literal['train', 'validation', 'test'] + + +class ImageDescription(NamedTuple): + id: int + file_name: str + original_size: Tuple[int, int] # w, h + url: Optional[str] = None + license: Optional[int] = None + coco_url: Optional[str] = None + date_captured: Optional[str] = None + flickr_url: Optional[str] = None + flickr_id: Optional[str] = None + coco_id: Optional[str] = None + + +class Category(NamedTuple): + id: str + super_category: Optional[str] + name: str + + +class Annotation(NamedTuple): + area: float + image_id: str + bbox: BoundingBox + category_no: int + category_id: str + id: Optional[int] = None + source: Optional[str] = None + confidence: Optional[float] = None + is_group_of: Optional[bool] = None + is_truncated: Optional[bool] = None + is_occluded: Optional[bool] = None + is_depiction: Optional[bool] = None + is_inside: Optional[bool] = None + segmentation: Optional[Dict] = None diff --git a/deforum-stable-diffusion/src/taming/data/image_transforms.py b/deforum-stable-diffusion/src/taming/data/image_transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..657ac332174e0ac72f68315271ffbd757b771a0f --- /dev/null +++ b/deforum-stable-diffusion/src/taming/data/image_transforms.py @@ -0,0 +1,132 @@ +import random +import warnings +from typing import Union + +import torch +from torch import Tensor +from torchvision.transforms import RandomCrop, functional as F, CenterCrop, RandomHorizontalFlip, PILToTensor +from torchvision.transforms.functional import _get_image_size as get_image_size + +from taming.data.helper_types import BoundingBox, Image + +pil_to_tensor = PILToTensor() + + +def convert_pil_to_tensor(image: Image) -> Tensor: + with warnings.catch_warnings(): + # to filter PyTorch UserWarning as described here: https://github.com/pytorch/vision/issues/2194 + warnings.simplefilter("ignore") + return pil_to_tensor(image) + + +class RandomCrop1dReturnCoordinates(RandomCrop): + def forward(self, img: Image) -> (BoundingBox, Image): + """ + Additionally to cropping, returns the relative coordinates of the crop bounding box. + Args: + img (PIL Image or Tensor): Image to be cropped. + + Returns: + Bounding box: x0, y0, w, h + PIL Image or Tensor: Cropped image. + + Based on: + torchvision.transforms.RandomCrop, torchvision 1.7.0 + """ + if self.padding is not None: + img = F.pad(img, self.padding, self.fill, self.padding_mode) + + width, height = get_image_size(img) + # pad the width if needed + if self.pad_if_needed and width < self.size[1]: + padding = [self.size[1] - width, 0] + img = F.pad(img, padding, self.fill, self.padding_mode) + # pad the height if needed + if self.pad_if_needed and height < self.size[0]: + padding = [0, self.size[0] - height] + img = F.pad(img, padding, self.fill, self.padding_mode) + + i, j, h, w = self.get_params(img, self.size) + bbox = (j / width, i / height, w / width, h / height) # x0, y0, w, h + return bbox, F.crop(img, i, j, h, w) + + +class Random2dCropReturnCoordinates(torch.nn.Module): + """ + Additionally to cropping, returns the relative coordinates of the crop bounding box. + Args: + img (PIL Image or Tensor): Image to be cropped. + + Returns: + Bounding box: x0, y0, w, h + PIL Image or Tensor: Cropped image. + + Based on: + torchvision.transforms.RandomCrop, torchvision 1.7.0 + """ + + def __init__(self, min_size: int): + super().__init__() + self.min_size = min_size + + def forward(self, img: Image) -> (BoundingBox, Image): + width, height = get_image_size(img) + max_size = min(width, height) + if max_size <= self.min_size: + size = max_size + else: + size = random.randint(self.min_size, max_size) + top = random.randint(0, height - size) + left = random.randint(0, width - size) + bbox = left / width, top / height, size / width, size / height + return bbox, F.crop(img, top, left, size, size) + + +class CenterCropReturnCoordinates(CenterCrop): + @staticmethod + def get_bbox_of_center_crop(width: int, height: int) -> BoundingBox: + if width > height: + w = height / width + h = 1.0 + x0 = 0.5 - w / 2 + y0 = 0. + else: + w = 1.0 + h = width / height + x0 = 0. + y0 = 0.5 - h / 2 + return x0, y0, w, h + + def forward(self, img: Union[Image, Tensor]) -> (BoundingBox, Union[Image, Tensor]): + """ + Additionally to cropping, returns the relative coordinates of the crop bounding box. + Args: + img (PIL Image or Tensor): Image to be cropped. + + Returns: + Bounding box: x0, y0, w, h + PIL Image or Tensor: Cropped image. + Based on: + torchvision.transforms.RandomHorizontalFlip (version 1.7.0) + """ + width, height = get_image_size(img) + return self.get_bbox_of_center_crop(width, height), F.center_crop(img, self.size) + + +class RandomHorizontalFlipReturn(RandomHorizontalFlip): + def forward(self, img: Image) -> (bool, Image): + """ + Additionally to flipping, returns a boolean whether it was flipped or not. + Args: + img (PIL Image or Tensor): Image to be flipped. + + Returns: + flipped: whether the image was flipped or not + PIL Image or Tensor: Randomly flipped image. + + Based on: + torchvision.transforms.RandomHorizontalFlip (version 1.7.0) + """ + if torch.rand(1) < self.p: + return True, F.hflip(img) + return False, img diff --git a/deforum-stable-diffusion/src/taming/data/imagenet.py b/deforum-stable-diffusion/src/taming/data/imagenet.py new file mode 100644 index 0000000000000000000000000000000000000000..9a02ec44ba4af9e993f58c91fa43482a4ecbe54c --- /dev/null +++ b/deforum-stable-diffusion/src/taming/data/imagenet.py @@ -0,0 +1,558 @@ +import os, tarfile, glob, shutil +import yaml +import numpy as np +from tqdm import tqdm +from PIL import Image +import albumentations +from omegaconf import OmegaConf +from torch.utils.data import Dataset + +from taming.data.base import ImagePaths +from taming.util import download, retrieve +import taming.data.utils as bdu + + +def give_synsets_from_indices(indices, path_to_yaml="data/imagenet_idx_to_synset.yaml"): + synsets = [] + with open(path_to_yaml) as f: + di2s = yaml.load(f) + for idx in indices: + synsets.append(str(di2s[idx])) + print("Using {} different synsets for construction of Restriced Imagenet.".format(len(synsets))) + return synsets + + +def str_to_indices(string): + """Expects a string in the format '32-123, 256, 280-321'""" + assert not string.endswith(","), "provided string '{}' ends with a comma, pls remove it".format(string) + subs = string.split(",") + indices = [] + for sub in subs: + subsubs = sub.split("-") + assert len(subsubs) > 0 + if len(subsubs) == 1: + indices.append(int(subsubs[0])) + else: + rang = [j for j in range(int(subsubs[0]), int(subsubs[1]))] + indices.extend(rang) + return sorted(indices) + + +class ImageNetBase(Dataset): + def __init__(self, config=None): + self.config = config or OmegaConf.create() + if not type(self.config)==dict: + self.config = OmegaConf.to_container(self.config) + self._prepare() + self._prepare_synset_to_human() + self._prepare_idx_to_synset() + self._load() + + def __len__(self): + return len(self.data) + + def __getitem__(self, i): + return self.data[i] + + def _prepare(self): + raise NotImplementedError() + + def _filter_relpaths(self, relpaths): + ignore = set([ + "n06596364_9591.JPEG", + ]) + relpaths = [rpath for rpath in relpaths if not rpath.split("/")[-1] in ignore] + if "sub_indices" in self.config: + indices = str_to_indices(self.config["sub_indices"]) + synsets = give_synsets_from_indices(indices, path_to_yaml=self.idx2syn) # returns a list of strings + files = [] + for rpath in relpaths: + syn = rpath.split("/")[0] + if syn in synsets: + files.append(rpath) + return files + else: + return relpaths + + def _prepare_synset_to_human(self): + SIZE = 2655750 + URL = "https://heibox.uni-heidelberg.de/f/9f28e956cd304264bb82/?dl=1" + self.human_dict = os.path.join(self.root, "synset_human.txt") + if (not os.path.exists(self.human_dict) or + not os.path.getsize(self.human_dict)==SIZE): + download(URL, self.human_dict) + + def _prepare_idx_to_synset(self): + URL = "https://heibox.uni-heidelberg.de/f/d835d5b6ceda4d3aa910/?dl=1" + self.idx2syn = os.path.join(self.root, "index_synset.yaml") + if (not os.path.exists(self.idx2syn)): + download(URL, self.idx2syn) + + def _load(self): + with open(self.txt_filelist, "r") as f: + self.relpaths = f.read().splitlines() + l1 = len(self.relpaths) + self.relpaths = self._filter_relpaths(self.relpaths) + print("Removed {} files from filelist during filtering.".format(l1 - len(self.relpaths))) + + self.synsets = [p.split("/")[0] for p in self.relpaths] + self.abspaths = [os.path.join(self.datadir, p) for p in self.relpaths] + + unique_synsets = np.unique(self.synsets) + class_dict = dict((synset, i) for i, synset in enumerate(unique_synsets)) + self.class_labels = [class_dict[s] for s in self.synsets] + + with open(self.human_dict, "r") as f: + human_dict = f.read().splitlines() + human_dict = dict(line.split(maxsplit=1) for line in human_dict) + + self.human_labels = [human_dict[s] for s in self.synsets] + + labels = { + "relpath": np.array(self.relpaths), + "synsets": np.array(self.synsets), + "class_label": np.array(self.class_labels), + "human_label": np.array(self.human_labels), + } + self.data = ImagePaths(self.abspaths, + labels=labels, + size=retrieve(self.config, "size", default=0), + random_crop=self.random_crop) + + +class ImageNetTrain(ImageNetBase): + NAME = "ILSVRC2012_train" + URL = "http://www.image-net.org/challenges/LSVRC/2012/" + AT_HASH = "a306397ccf9c2ead27155983c254227c0fd938e2" + FILES = [ + "ILSVRC2012_img_train.tar", + ] + SIZES = [ + 147897477120, + ] + + def _prepare(self): + self.random_crop = retrieve(self.config, "ImageNetTrain/random_crop", + default=True) + cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache")) + self.root = os.path.join(cachedir, "autoencoders/data", self.NAME) + self.datadir = os.path.join(self.root, "data") + self.txt_filelist = os.path.join(self.root, "filelist.txt") + self.expected_length = 1281167 + if not bdu.is_prepared(self.root): + # prep + print("Preparing dataset {} in {}".format(self.NAME, self.root)) + + datadir = self.datadir + if not os.path.exists(datadir): + path = os.path.join(self.root, self.FILES[0]) + if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]: + import academictorrents as at + atpath = at.get(self.AT_HASH, datastore=self.root) + assert atpath == path + + print("Extracting {} to {}".format(path, datadir)) + os.makedirs(datadir, exist_ok=True) + with tarfile.open(path, "r:") as tar: + tar.extractall(path=datadir) + + print("Extracting sub-tars.") + subpaths = sorted(glob.glob(os.path.join(datadir, "*.tar"))) + for subpath in tqdm(subpaths): + subdir = subpath[:-len(".tar")] + os.makedirs(subdir, exist_ok=True) + with tarfile.open(subpath, "r:") as tar: + tar.extractall(path=subdir) + + + filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG")) + filelist = [os.path.relpath(p, start=datadir) for p in filelist] + filelist = sorted(filelist) + filelist = "\n".join(filelist)+"\n" + with open(self.txt_filelist, "w") as f: + f.write(filelist) + + bdu.mark_prepared(self.root) + + +class ImageNetValidation(ImageNetBase): + NAME = "ILSVRC2012_validation" + URL = "http://www.image-net.org/challenges/LSVRC/2012/" + AT_HASH = "5d6d0df7ed81efd49ca99ea4737e0ae5e3a5f2e5" + VS_URL = "https://heibox.uni-heidelberg.de/f/3e0f6e9c624e45f2bd73/?dl=1" + FILES = [ + "ILSVRC2012_img_val.tar", + "validation_synset.txt", + ] + SIZES = [ + 6744924160, + 1950000, + ] + + def _prepare(self): + self.random_crop = retrieve(self.config, "ImageNetValidation/random_crop", + default=False) + cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache")) + self.root = os.path.join(cachedir, "autoencoders/data", self.NAME) + self.datadir = os.path.join(self.root, "data") + self.txt_filelist = os.path.join(self.root, "filelist.txt") + self.expected_length = 50000 + if not bdu.is_prepared(self.root): + # prep + print("Preparing dataset {} in {}".format(self.NAME, self.root)) + + datadir = self.datadir + if not os.path.exists(datadir): + path = os.path.join(self.root, self.FILES[0]) + if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]: + import academictorrents as at + atpath = at.get(self.AT_HASH, datastore=self.root) + assert atpath == path + + print("Extracting {} to {}".format(path, datadir)) + os.makedirs(datadir, exist_ok=True) + with tarfile.open(path, "r:") as tar: + tar.extractall(path=datadir) + + vspath = os.path.join(self.root, self.FILES[1]) + if not os.path.exists(vspath) or not os.path.getsize(vspath)==self.SIZES[1]: + download(self.VS_URL, vspath) + + with open(vspath, "r") as f: + synset_dict = f.read().splitlines() + synset_dict = dict(line.split() for line in synset_dict) + + print("Reorganizing into synset folders") + synsets = np.unique(list(synset_dict.values())) + for s in synsets: + os.makedirs(os.path.join(datadir, s), exist_ok=True) + for k, v in synset_dict.items(): + src = os.path.join(datadir, k) + dst = os.path.join(datadir, v) + shutil.move(src, dst) + + filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG")) + filelist = [os.path.relpath(p, start=datadir) for p in filelist] + filelist = sorted(filelist) + filelist = "\n".join(filelist)+"\n" + with open(self.txt_filelist, "w") as f: + f.write(filelist) + + bdu.mark_prepared(self.root) + + +def get_preprocessor(size=None, random_crop=False, additional_targets=None, + crop_size=None): + if size is not None and size > 0: + transforms = list() + rescaler = albumentations.SmallestMaxSize(max_size = size) + transforms.append(rescaler) + if not random_crop: + cropper = albumentations.CenterCrop(height=size,width=size) + transforms.append(cropper) + else: + cropper = albumentations.RandomCrop(height=size,width=size) + transforms.append(cropper) + flipper = albumentations.HorizontalFlip() + transforms.append(flipper) + preprocessor = albumentations.Compose(transforms, + additional_targets=additional_targets) + elif crop_size is not None and crop_size > 0: + if not random_crop: + cropper = albumentations.CenterCrop(height=crop_size,width=crop_size) + else: + cropper = albumentations.RandomCrop(height=crop_size,width=crop_size) + transforms = [cropper] + preprocessor = albumentations.Compose(transforms, + additional_targets=additional_targets) + else: + preprocessor = lambda **kwargs: kwargs + return preprocessor + + +def rgba_to_depth(x): + assert x.dtype == np.uint8 + assert len(x.shape) == 3 and x.shape[2] == 4 + y = x.copy() + y.dtype = np.float32 + y = y.reshape(x.shape[:2]) + return np.ascontiguousarray(y) + + +class BaseWithDepth(Dataset): + DEFAULT_DEPTH_ROOT="data/imagenet_depth" + + def __init__(self, config=None, size=None, random_crop=False, + crop_size=None, root=None): + self.config = config + self.base_dset = self.get_base_dset() + self.preprocessor = get_preprocessor( + size=size, + crop_size=crop_size, + random_crop=random_crop, + additional_targets={"depth": "image"}) + self.crop_size = crop_size + if self.crop_size is not None: + self.rescaler = albumentations.Compose( + [albumentations.SmallestMaxSize(max_size = self.crop_size)], + additional_targets={"depth": "image"}) + if root is not None: + self.DEFAULT_DEPTH_ROOT = root + + def __len__(self): + return len(self.base_dset) + + def preprocess_depth(self, path): + rgba = np.array(Image.open(path)) + depth = rgba_to_depth(rgba) + depth = (depth - depth.min())/max(1e-8, depth.max()-depth.min()) + depth = 2.0*depth-1.0 + return depth + + def __getitem__(self, i): + e = self.base_dset[i] + e["depth"] = self.preprocess_depth(self.get_depth_path(e)) + # up if necessary + h,w,c = e["image"].shape + if self.crop_size and min(h,w) < self.crop_size: + # have to upscale to be able to crop - this just uses bilinear + out = self.rescaler(image=e["image"], depth=e["depth"]) + e["image"] = out["image"] + e["depth"] = out["depth"] + transformed = self.preprocessor(image=e["image"], depth=e["depth"]) + e["image"] = transformed["image"] + e["depth"] = transformed["depth"] + return e + + +class ImageNetTrainWithDepth(BaseWithDepth): + # default to random_crop=True + def __init__(self, random_crop=True, sub_indices=None, **kwargs): + self.sub_indices = sub_indices + super().__init__(random_crop=random_crop, **kwargs) + + def get_base_dset(self): + if self.sub_indices is None: + return ImageNetTrain() + else: + return ImageNetTrain({"sub_indices": self.sub_indices}) + + def get_depth_path(self, e): + fid = os.path.splitext(e["relpath"])[0]+".png" + fid = os.path.join(self.DEFAULT_DEPTH_ROOT, "train", fid) + return fid + + +class ImageNetValidationWithDepth(BaseWithDepth): + def __init__(self, sub_indices=None, **kwargs): + self.sub_indices = sub_indices + super().__init__(**kwargs) + + def get_base_dset(self): + if self.sub_indices is None: + return ImageNetValidation() + else: + return ImageNetValidation({"sub_indices": self.sub_indices}) + + def get_depth_path(self, e): + fid = os.path.splitext(e["relpath"])[0]+".png" + fid = os.path.join(self.DEFAULT_DEPTH_ROOT, "val", fid) + return fid + + +class RINTrainWithDepth(ImageNetTrainWithDepth): + def __init__(self, config=None, size=None, random_crop=True, crop_size=None): + sub_indices = "30-32, 33-37, 151-268, 281-285, 80-100, 365-382, 389-397, 118-121, 300-319" + super().__init__(config=config, size=size, random_crop=random_crop, + sub_indices=sub_indices, crop_size=crop_size) + + +class RINValidationWithDepth(ImageNetValidationWithDepth): + def __init__(self, config=None, size=None, random_crop=False, crop_size=None): + sub_indices = "30-32, 33-37, 151-268, 281-285, 80-100, 365-382, 389-397, 118-121, 300-319" + super().__init__(config=config, size=size, random_crop=random_crop, + sub_indices=sub_indices, crop_size=crop_size) + + +class DRINExamples(Dataset): + def __init__(self): + self.preprocessor = get_preprocessor(size=256, additional_targets={"depth": "image"}) + with open("data/drin_examples.txt", "r") as f: + relpaths = f.read().splitlines() + self.image_paths = [os.path.join("data/drin_images", + relpath) for relpath in relpaths] + self.depth_paths = [os.path.join("data/drin_depth", + relpath.replace(".JPEG", ".png")) for relpath in relpaths] + + def __len__(self): + return len(self.image_paths) + + def preprocess_image(self, image_path): + image = Image.open(image_path) + if not image.mode == "RGB": + image = image.convert("RGB") + image = np.array(image).astype(np.uint8) + image = self.preprocessor(image=image)["image"] + image = (image/127.5 - 1.0).astype(np.float32) + return image + + def preprocess_depth(self, path): + rgba = np.array(Image.open(path)) + depth = rgba_to_depth(rgba) + depth = (depth - depth.min())/max(1e-8, depth.max()-depth.min()) + depth = 2.0*depth-1.0 + return depth + + def __getitem__(self, i): + e = dict() + e["image"] = self.preprocess_image(self.image_paths[i]) + e["depth"] = self.preprocess_depth(self.depth_paths[i]) + transformed = self.preprocessor(image=e["image"], depth=e["depth"]) + e["image"] = transformed["image"] + e["depth"] = transformed["depth"] + return e + + +def imscale(x, factor, keepshapes=False, keepmode="bicubic"): + if factor is None or factor==1: + return x + + dtype = x.dtype + assert dtype in [np.float32, np.float64] + assert x.min() >= -1 + assert x.max() <= 1 + + keepmode = {"nearest": Image.NEAREST, "bilinear": Image.BILINEAR, + "bicubic": Image.BICUBIC}[keepmode] + + lr = (x+1.0)*127.5 + lr = lr.clip(0,255).astype(np.uint8) + lr = Image.fromarray(lr) + + h, w, _ = x.shape + nh = h//factor + nw = w//factor + assert nh > 0 and nw > 0, (nh, nw) + + lr = lr.resize((nw,nh), Image.BICUBIC) + if keepshapes: + lr = lr.resize((w,h), keepmode) + lr = np.array(lr)/127.5-1.0 + lr = lr.astype(dtype) + + return lr + + +class ImageNetScale(Dataset): + def __init__(self, size=None, crop_size=None, random_crop=False, + up_factor=None, hr_factor=None, keep_mode="bicubic"): + self.base = self.get_base() + + self.size = size + self.crop_size = crop_size if crop_size is not None else self.size + self.random_crop = random_crop + self.up_factor = up_factor + self.hr_factor = hr_factor + self.keep_mode = keep_mode + + transforms = list() + + if self.size is not None and self.size > 0: + rescaler = albumentations.SmallestMaxSize(max_size = self.size) + self.rescaler = rescaler + transforms.append(rescaler) + + if self.crop_size is not None and self.crop_size > 0: + if len(transforms) == 0: + self.rescaler = albumentations.SmallestMaxSize(max_size = self.crop_size) + + if not self.random_crop: + cropper = albumentations.CenterCrop(height=self.crop_size,width=self.crop_size) + else: + cropper = albumentations.RandomCrop(height=self.crop_size,width=self.crop_size) + transforms.append(cropper) + + if len(transforms) > 0: + if self.up_factor is not None: + additional_targets = {"lr": "image"} + else: + additional_targets = None + self.preprocessor = albumentations.Compose(transforms, + additional_targets=additional_targets) + else: + self.preprocessor = lambda **kwargs: kwargs + + def __len__(self): + return len(self.base) + + def __getitem__(self, i): + example = self.base[i] + image = example["image"] + # adjust resolution + image = imscale(image, self.hr_factor, keepshapes=False) + h,w,c = image.shape + if self.crop_size and min(h,w) < self.crop_size: + # have to upscale to be able to crop - this just uses bilinear + image = self.rescaler(image=image)["image"] + if self.up_factor is None: + image = self.preprocessor(image=image)["image"] + example["image"] = image + else: + lr = imscale(image, self.up_factor, keepshapes=True, + keepmode=self.keep_mode) + + out = self.preprocessor(image=image, lr=lr) + example["image"] = out["image"] + example["lr"] = out["lr"] + + return example + +class ImageNetScaleTrain(ImageNetScale): + def __init__(self, random_crop=True, **kwargs): + super().__init__(random_crop=random_crop, **kwargs) + + def get_base(self): + return ImageNetTrain() + +class ImageNetScaleValidation(ImageNetScale): + def get_base(self): + return ImageNetValidation() + + +from skimage.feature import canny +from skimage.color import rgb2gray + + +class ImageNetEdges(ImageNetScale): + def __init__(self, up_factor=1, **kwargs): + super().__init__(up_factor=1, **kwargs) + + def __getitem__(self, i): + example = self.base[i] + image = example["image"] + h,w,c = image.shape + if self.crop_size and min(h,w) < self.crop_size: + # have to upscale to be able to crop - this just uses bilinear + image = self.rescaler(image=image)["image"] + + lr = canny(rgb2gray(image), sigma=2) + lr = lr.astype(np.float32) + lr = lr[:,:,None][:,:,[0,0,0]] + + out = self.preprocessor(image=image, lr=lr) + example["image"] = out["image"] + example["lr"] = out["lr"] + + return example + + +class ImageNetEdgesTrain(ImageNetEdges): + def __init__(self, random_crop=True, **kwargs): + super().__init__(random_crop=random_crop, **kwargs) + + def get_base(self): + return ImageNetTrain() + +class ImageNetEdgesValidation(ImageNetEdges): + def get_base(self): + return ImageNetValidation() diff --git a/deforum-stable-diffusion/src/taming/data/open_images_helper.py b/deforum-stable-diffusion/src/taming/data/open_images_helper.py new file mode 100644 index 0000000000000000000000000000000000000000..8feb7c6e705fc165d2983303192aaa88f579b243 --- /dev/null +++ b/deforum-stable-diffusion/src/taming/data/open_images_helper.py @@ -0,0 +1,379 @@ +open_images_unify_categories_for_coco = { + '/m/03bt1vf': '/m/01g317', + '/m/04yx4': '/m/01g317', + '/m/05r655': '/m/01g317', + '/m/01bl7v': '/m/01g317', + '/m/0cnyhnx': '/m/01xq0k1', + '/m/01226z': '/m/018xm', + '/m/05ctyq': '/m/018xm', + '/m/058qzx': '/m/04ctx', + '/m/06pcq': '/m/0l515', + '/m/03m3pdh': '/m/02crq1', + '/m/046dlr': '/m/01x3z', + '/m/0h8mzrc': '/m/01x3z', +} + + +top_300_classes_plus_coco_compatibility = [ + ('Man', 1060962), + ('Clothing', 986610), + ('Tree', 748162), + ('Woman', 611896), + ('Person', 610294), + ('Human face', 442948), + ('Girl', 175399), + ('Building', 162147), + ('Car', 159135), + ('Plant', 155704), + ('Human body', 137073), + ('Flower', 133128), + ('Window', 127485), + ('Human arm', 118380), + ('House', 114365), + ('Wheel', 111684), + ('Suit', 99054), + ('Human hair', 98089), + ('Human head', 92763), + ('Chair', 88624), + ('Boy', 79849), + ('Table', 73699), + ('Jeans', 57200), + ('Tire', 55725), + ('Skyscraper', 53321), + ('Food', 52400), + ('Footwear', 50335), + ('Dress', 50236), + ('Human leg', 47124), + ('Toy', 46636), + ('Tower', 45605), + ('Boat', 43486), + ('Land vehicle', 40541), + ('Bicycle wheel', 34646), + ('Palm tree', 33729), + ('Fashion accessory', 32914), + ('Glasses', 31940), + ('Bicycle', 31409), + ('Furniture', 30656), + ('Sculpture', 29643), + ('Bottle', 27558), + ('Dog', 26980), + ('Snack', 26796), + ('Human hand', 26664), + ('Bird', 25791), + ('Book', 25415), + ('Guitar', 24386), + ('Jacket', 23998), + ('Poster', 22192), + ('Dessert', 21284), + ('Baked goods', 20657), + ('Drink', 19754), + ('Flag', 18588), + ('Houseplant', 18205), + ('Tableware', 17613), + ('Airplane', 17218), + ('Door', 17195), + ('Sports uniform', 17068), + ('Shelf', 16865), + ('Drum', 16612), + ('Vehicle', 16542), + ('Microphone', 15269), + ('Street light', 14957), + ('Cat', 14879), + ('Fruit', 13684), + ('Fast food', 13536), + ('Animal', 12932), + ('Vegetable', 12534), + ('Train', 12358), + ('Horse', 11948), + ('Flowerpot', 11728), + ('Motorcycle', 11621), + ('Fish', 11517), + ('Desk', 11405), + ('Helmet', 10996), + ('Truck', 10915), + ('Bus', 10695), + ('Hat', 10532), + ('Auto part', 10488), + ('Musical instrument', 10303), + ('Sunglasses', 10207), + ('Picture frame', 10096), + ('Sports equipment', 10015), + ('Shorts', 9999), + ('Wine glass', 9632), + ('Duck', 9242), + ('Wine', 9032), + ('Rose', 8781), + ('Tie', 8693), + ('Butterfly', 8436), + ('Beer', 7978), + ('Cabinetry', 7956), + ('Laptop', 7907), + ('Insect', 7497), + ('Goggles', 7363), + ('Shirt', 7098), + ('Dairy Product', 7021), + ('Marine invertebrates', 7014), + ('Cattle', 7006), + ('Trousers', 6903), + ('Van', 6843), + ('Billboard', 6777), + ('Balloon', 6367), + ('Human nose', 6103), + ('Tent', 6073), + ('Camera', 6014), + ('Doll', 6002), + ('Coat', 5951), + ('Mobile phone', 5758), + ('Swimwear', 5729), + ('Strawberry', 5691), + ('Stairs', 5643), + ('Goose', 5599), + ('Umbrella', 5536), + ('Cake', 5508), + ('Sun hat', 5475), + ('Bench', 5310), + ('Bookcase', 5163), + ('Bee', 5140), + ('Computer monitor', 5078), + ('Hiking equipment', 4983), + ('Office building', 4981), + ('Coffee cup', 4748), + ('Curtain', 4685), + ('Plate', 4651), + ('Box', 4621), + ('Tomato', 4595), + ('Coffee table', 4529), + ('Office supplies', 4473), + ('Maple', 4416), + ('Muffin', 4365), + ('Cocktail', 4234), + ('Castle', 4197), + ('Couch', 4134), + ('Pumpkin', 3983), + ('Computer keyboard', 3960), + ('Human mouth', 3926), + ('Christmas tree', 3893), + ('Mushroom', 3883), + ('Swimming pool', 3809), + ('Pastry', 3799), + ('Lavender (Plant)', 3769), + ('Football helmet', 3732), + ('Bread', 3648), + ('Traffic sign', 3628), + ('Common sunflower', 3597), + ('Television', 3550), + ('Bed', 3525), + ('Cookie', 3485), + ('Fountain', 3484), + ('Paddle', 3447), + ('Bicycle helmet', 3429), + ('Porch', 3420), + ('Deer', 3387), + ('Fedora', 3339), + ('Canoe', 3338), + ('Carnivore', 3266), + ('Bowl', 3202), + ('Human eye', 3166), + ('Ball', 3118), + ('Pillow', 3077), + ('Salad', 3061), + ('Beetle', 3060), + ('Orange', 3050), + ('Drawer', 2958), + ('Platter', 2937), + ('Elephant', 2921), + ('Seafood', 2921), + ('Monkey', 2915), + ('Countertop', 2879), + ('Watercraft', 2831), + ('Helicopter', 2805), + ('Kitchen appliance', 2797), + ('Personal flotation device', 2781), + ('Swan', 2739), + ('Lamp', 2711), + ('Boot', 2695), + ('Bronze sculpture', 2693), + ('Chicken', 2677), + ('Taxi', 2643), + ('Juice', 2615), + ('Cowboy hat', 2604), + ('Apple', 2600), + ('Tin can', 2590), + ('Necklace', 2564), + ('Ice cream', 2560), + ('Human beard', 2539), + ('Coin', 2536), + ('Candle', 2515), + ('Cart', 2512), + ('High heels', 2441), + ('Weapon', 2433), + ('Handbag', 2406), + ('Penguin', 2396), + ('Rifle', 2352), + ('Violin', 2336), + ('Skull', 2304), + ('Lantern', 2285), + ('Scarf', 2269), + ('Saucer', 2225), + ('Sheep', 2215), + ('Vase', 2189), + ('Lily', 2180), + ('Mug', 2154), + ('Parrot', 2140), + ('Human ear', 2137), + ('Sandal', 2115), + ('Lizard', 2100), + ('Kitchen & dining room table', 2063), + ('Spider', 1977), + ('Coffee', 1974), + ('Goat', 1926), + ('Squirrel', 1922), + ('Cello', 1913), + ('Sushi', 1881), + ('Tortoise', 1876), + ('Pizza', 1870), + ('Studio couch', 1864), + ('Barrel', 1862), + ('Cosmetics', 1841), + ('Moths and butterflies', 1841), + ('Convenience store', 1817), + ('Watch', 1792), + ('Home appliance', 1786), + ('Harbor seal', 1780), + ('Luggage and bags', 1756), + ('Vehicle registration plate', 1754), + ('Shrimp', 1751), + ('Jellyfish', 1730), + ('French fries', 1723), + ('Egg (Food)', 1698), + ('Football', 1697), + ('Musical keyboard', 1683), + ('Falcon', 1674), + ('Candy', 1660), + ('Medical equipment', 1654), + ('Eagle', 1651), + ('Dinosaur', 1634), + ('Surfboard', 1630), + ('Tank', 1628), + ('Grape', 1624), + ('Lion', 1624), + ('Owl', 1622), + ('Ski', 1613), + ('Waste container', 1606), + ('Frog', 1591), + ('Sparrow', 1585), + ('Rabbit', 1581), + ('Pen', 1546), + ('Sea lion', 1537), + ('Spoon', 1521), + ('Sink', 1512), + ('Teddy bear', 1507), + ('Bull', 1495), + ('Sofa bed', 1490), + ('Dragonfly', 1479), + ('Brassiere', 1478), + ('Chest of drawers', 1472), + ('Aircraft', 1466), + ('Human foot', 1463), + ('Pig', 1455), + ('Fork', 1454), + ('Antelope', 1438), + ('Tripod', 1427), + ('Tool', 1424), + ('Cheese', 1422), + ('Lemon', 1397), + ('Hamburger', 1393), + ('Dolphin', 1390), + ('Mirror', 1390), + ('Marine mammal', 1387), + ('Giraffe', 1385), + ('Snake', 1368), + ('Gondola', 1364), + ('Wheelchair', 1360), + ('Piano', 1358), + ('Cupboard', 1348), + ('Banana', 1345), + ('Trumpet', 1335), + ('Lighthouse', 1333), + ('Invertebrate', 1317), + ('Carrot', 1268), + ('Sock', 1260), + ('Tiger', 1241), + ('Camel', 1224), + ('Parachute', 1224), + ('Bathroom accessory', 1223), + ('Earrings', 1221), + ('Headphones', 1218), + ('Skirt', 1198), + ('Skateboard', 1190), + ('Sandwich', 1148), + ('Saxophone', 1141), + ('Goldfish', 1136), + ('Stool', 1104), + ('Traffic light', 1097), + ('Shellfish', 1081), + ('Backpack', 1079), + ('Sea turtle', 1078), + ('Cucumber', 1075), + ('Tea', 1051), + ('Toilet', 1047), + ('Roller skates', 1040), + ('Mule', 1039), + ('Bust', 1031), + ('Broccoli', 1030), + ('Crab', 1020), + ('Oyster', 1019), + ('Cannon', 1012), + ('Zebra', 1012), + ('French horn', 1008), + ('Grapefruit', 998), + ('Whiteboard', 997), + ('Zucchini', 997), + ('Crocodile', 992), + + ('Clock', 960), + ('Wall clock', 958), + + ('Doughnut', 869), + ('Snail', 868), + + ('Baseball glove', 859), + + ('Panda', 830), + ('Tennis racket', 830), + + ('Pear', 652), + + ('Bagel', 617), + ('Oven', 616), + ('Ladybug', 615), + ('Shark', 615), + ('Polar bear', 614), + ('Ostrich', 609), + + ('Hot dog', 473), + ('Microwave oven', 467), + ('Fire hydrant', 20), + ('Stop sign', 20), + ('Parking meter', 20), + ('Bear', 20), + ('Flying disc', 20), + ('Snowboard', 20), + ('Tennis ball', 20), + ('Kite', 20), + ('Baseball bat', 20), + ('Kitchen knife', 20), + ('Knife', 20), + ('Submarine sandwich', 20), + ('Computer mouse', 20), + ('Remote control', 20), + ('Toaster', 20), + ('Sink', 20), + ('Refrigerator', 20), + ('Alarm clock', 20), + ('Wall clock', 20), + ('Scissors', 20), + ('Hair dryer', 20), + ('Toothbrush', 20), + ('Suitcase', 20) +] diff --git a/deforum-stable-diffusion/src/taming/data/sflckr.py b/deforum-stable-diffusion/src/taming/data/sflckr.py new file mode 100644 index 0000000000000000000000000000000000000000..91101be5953b113f1e58376af637e43f366b3dee --- /dev/null +++ b/deforum-stable-diffusion/src/taming/data/sflckr.py @@ -0,0 +1,91 @@ +import os +import numpy as np +import cv2 +import albumentations +from PIL import Image +from torch.utils.data import Dataset + + +class SegmentationBase(Dataset): + def __init__(self, + data_csv, data_root, segmentation_root, + size=None, random_crop=False, interpolation="bicubic", + n_labels=182, shift_segmentation=False, + ): + self.n_labels = n_labels + self.shift_segmentation = shift_segmentation + self.data_csv = data_csv + self.data_root = data_root + self.segmentation_root = segmentation_root + with open(self.data_csv, "r") as f: + self.image_paths = f.read().splitlines() + self._length = len(self.image_paths) + self.labels = { + "relative_file_path_": [l for l in self.image_paths], + "file_path_": [os.path.join(self.data_root, l) + for l in self.image_paths], + "segmentation_path_": [os.path.join(self.segmentation_root, l.replace(".jpg", ".png")) + for l in self.image_paths] + } + + size = None if size is not None and size<=0 else size + self.size = size + if self.size is not None: + self.interpolation = interpolation + self.interpolation = { + "nearest": cv2.INTER_NEAREST, + "bilinear": cv2.INTER_LINEAR, + "bicubic": cv2.INTER_CUBIC, + "area": cv2.INTER_AREA, + "lanczos": cv2.INTER_LANCZOS4}[self.interpolation] + self.image_rescaler = albumentations.SmallestMaxSize(max_size=self.size, + interpolation=self.interpolation) + self.segmentation_rescaler = albumentations.SmallestMaxSize(max_size=self.size, + interpolation=cv2.INTER_NEAREST) + self.center_crop = not random_crop + if self.center_crop: + self.cropper = albumentations.CenterCrop(height=self.size, width=self.size) + else: + self.cropper = albumentations.RandomCrop(height=self.size, width=self.size) + self.preprocessor = self.cropper + + def __len__(self): + return self._length + + def __getitem__(self, i): + example = dict((k, self.labels[k][i]) for k in self.labels) + image = Image.open(example["file_path_"]) + if not image.mode == "RGB": + image = image.convert("RGB") + image = np.array(image).astype(np.uint8) + if self.size is not None: + image = self.image_rescaler(image=image)["image"] + segmentation = Image.open(example["segmentation_path_"]) + assert segmentation.mode == "L", segmentation.mode + segmentation = np.array(segmentation).astype(np.uint8) + if self.shift_segmentation: + # used to support segmentations containing unlabeled==255 label + segmentation = segmentation+1 + if self.size is not None: + segmentation = self.segmentation_rescaler(image=segmentation)["image"] + if self.size is not None: + processed = self.preprocessor(image=image, + mask=segmentation + ) + else: + processed = {"image": image, + "mask": segmentation + } + example["image"] = (processed["image"]/127.5 - 1.0).astype(np.float32) + segmentation = processed["mask"] + onehot = np.eye(self.n_labels)[segmentation] + example["segmentation"] = onehot + return example + + +class Examples(SegmentationBase): + def __init__(self, size=None, random_crop=False, interpolation="bicubic"): + super().__init__(data_csv="data/sflckr_examples.txt", + data_root="data/sflckr_images", + segmentation_root="data/sflckr_segmentations", + size=size, random_crop=random_crop, interpolation=interpolation) diff --git a/deforum-stable-diffusion/src/taming/data/utils.py b/deforum-stable-diffusion/src/taming/data/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..2b3c3d53cd2b6c72b481b59834cf809d3735b394 --- /dev/null +++ b/deforum-stable-diffusion/src/taming/data/utils.py @@ -0,0 +1,169 @@ +import collections +import os +import tarfile +import urllib +import zipfile +from pathlib import Path + +import numpy as np +import torch +from taming.data.helper_types import Annotation +from torch._six import string_classes +from torch.utils.data._utils.collate import np_str_obj_array_pattern, default_collate_err_msg_format +from tqdm import tqdm + + +def unpack(path): + if path.endswith("tar.gz"): + with tarfile.open(path, "r:gz") as tar: + tar.extractall(path=os.path.split(path)[0]) + elif path.endswith("tar"): + with tarfile.open(path, "r:") as tar: + tar.extractall(path=os.path.split(path)[0]) + elif path.endswith("zip"): + with zipfile.ZipFile(path, "r") as f: + f.extractall(path=os.path.split(path)[0]) + else: + raise NotImplementedError( + "Unknown file extension: {}".format(os.path.splitext(path)[1]) + ) + + +def reporthook(bar): + """tqdm progress bar for downloads.""" + + def hook(b=1, bsize=1, tsize=None): + if tsize is not None: + bar.total = tsize + bar.update(b * bsize - bar.n) + + return hook + + +def get_root(name): + base = "data/" + root = os.path.join(base, name) + os.makedirs(root, exist_ok=True) + return root + + +def is_prepared(root): + return Path(root).joinpath(".ready").exists() + + +def mark_prepared(root): + Path(root).joinpath(".ready").touch() + + +def prompt_download(file_, source, target_dir, content_dir=None): + targetpath = os.path.join(target_dir, file_) + while not os.path.exists(targetpath): + if content_dir is not None and os.path.exists( + os.path.join(target_dir, content_dir) + ): + break + print( + "Please download '{}' from '{}' to '{}'.".format(file_, source, targetpath) + ) + if content_dir is not None: + print( + "Or place its content into '{}'.".format( + os.path.join(target_dir, content_dir) + ) + ) + input("Press Enter when done...") + return targetpath + + +def download_url(file_, url, target_dir): + targetpath = os.path.join(target_dir, file_) + os.makedirs(target_dir, exist_ok=True) + with tqdm( + unit="B", unit_scale=True, unit_divisor=1024, miniters=1, desc=file_ + ) as bar: + urllib.request.urlretrieve(url, targetpath, reporthook=reporthook(bar)) + return targetpath + + +def download_urls(urls, target_dir): + paths = dict() + for fname, url in urls.items(): + outpath = download_url(fname, url, target_dir) + paths[fname] = outpath + return paths + + +def quadratic_crop(x, bbox, alpha=1.0): + """bbox is xmin, ymin, xmax, ymax""" + im_h, im_w = x.shape[:2] + bbox = np.array(bbox, dtype=np.float32) + bbox = np.clip(bbox, 0, max(im_h, im_w)) + center = 0.5 * (bbox[0] + bbox[2]), 0.5 * (bbox[1] + bbox[3]) + w = bbox[2] - bbox[0] + h = bbox[3] - bbox[1] + l = int(alpha * max(w, h)) + l = max(l, 2) + + required_padding = -1 * min( + center[0] - l, center[1] - l, im_w - (center[0] + l), im_h - (center[1] + l) + ) + required_padding = int(np.ceil(required_padding)) + if required_padding > 0: + padding = [ + [required_padding, required_padding], + [required_padding, required_padding], + ] + padding += [[0, 0]] * (len(x.shape) - 2) + x = np.pad(x, padding, "reflect") + center = center[0] + required_padding, center[1] + required_padding + xmin = int(center[0] - l / 2) + ymin = int(center[1] - l / 2) + return np.array(x[ymin : ymin + l, xmin : xmin + l, ...]) + + +def custom_collate(batch): + r"""source: pytorch 1.9.0, only one modification to original code """ + + elem = batch[0] + elem_type = type(elem) + if isinstance(elem, torch.Tensor): + out = None + if torch.utils.data.get_worker_info() is not None: + # If we're in a background process, concatenate directly into a + # shared memory tensor to avoid an extra copy + numel = sum([x.numel() for x in batch]) + storage = elem.storage()._new_shared(numel) + out = elem.new(storage) + return torch.stack(batch, 0, out=out) + elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \ + and elem_type.__name__ != 'string_': + if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap': + # array of string classes and object + if np_str_obj_array_pattern.search(elem.dtype.str) is not None: + raise TypeError(default_collate_err_msg_format.format(elem.dtype)) + + return custom_collate([torch.as_tensor(b) for b in batch]) + elif elem.shape == (): # scalars + return torch.as_tensor(batch) + elif isinstance(elem, float): + return torch.tensor(batch, dtype=torch.float64) + elif isinstance(elem, int): + return torch.tensor(batch) + elif isinstance(elem, string_classes): + return batch + elif isinstance(elem, collections.abc.Mapping): + return {key: custom_collate([d[key] for d in batch]) for key in elem} + elif isinstance(elem, tuple) and hasattr(elem, '_fields'): # namedtuple + return elem_type(*(custom_collate(samples) for samples in zip(*batch))) + if isinstance(elem, collections.abc.Sequence) and isinstance(elem[0], Annotation): # added + return batch # added + elif isinstance(elem, collections.abc.Sequence): + # check to make sure that the elements in batch have consistent size + it = iter(batch) + elem_size = len(next(it)) + if not all(len(elem) == elem_size for elem in it): + raise RuntimeError('each element in list of batch should be of equal size') + transposed = zip(*batch) + return [custom_collate(samples) for samples in transposed] + + raise TypeError(default_collate_err_msg_format.format(elem_type)) diff --git a/deforum-stable-diffusion/src/taming/lr_scheduler.py b/deforum-stable-diffusion/src/taming/lr_scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..e598ed120159c53da6820a55ad86b89f5c70c82d --- /dev/null +++ b/deforum-stable-diffusion/src/taming/lr_scheduler.py @@ -0,0 +1,34 @@ +import numpy as np + + +class LambdaWarmUpCosineScheduler: + """ + note: use with a base_lr of 1.0 + """ + def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0): + self.lr_warm_up_steps = warm_up_steps + self.lr_start = lr_start + self.lr_min = lr_min + self.lr_max = lr_max + self.lr_max_decay_steps = max_decay_steps + self.last_lr = 0. + self.verbosity_interval = verbosity_interval + + def schedule(self, n): + if self.verbosity_interval > 0: + if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}") + if n < self.lr_warm_up_steps: + lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start + self.last_lr = lr + return lr + else: + t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps) + t = min(t, 1.0) + lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * ( + 1 + np.cos(t * np.pi)) + self.last_lr = lr + return lr + + def __call__(self, n): + return self.schedule(n) + diff --git a/deforum-stable-diffusion/src/taming/models/cond_transformer.py b/deforum-stable-diffusion/src/taming/models/cond_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..e4c63730fa86ac1b92b37af14c14fb696595b1ab --- /dev/null +++ b/deforum-stable-diffusion/src/taming/models/cond_transformer.py @@ -0,0 +1,352 @@ +import os, math +import torch +import torch.nn.functional as F +import pytorch_lightning as pl + +from main import instantiate_from_config +from taming.modules.util import SOSProvider + + +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + + +class Net2NetTransformer(pl.LightningModule): + def __init__(self, + transformer_config, + first_stage_config, + cond_stage_config, + permuter_config=None, + ckpt_path=None, + ignore_keys=[], + first_stage_key="image", + cond_stage_key="depth", + downsample_cond_size=-1, + pkeep=1.0, + sos_token=0, + unconditional=False, + ): + super().__init__() + self.be_unconditional = unconditional + self.sos_token = sos_token + self.first_stage_key = first_stage_key + self.cond_stage_key = cond_stage_key + self.init_first_stage_from_ckpt(first_stage_config) + self.init_cond_stage_from_ckpt(cond_stage_config) + if permuter_config is None: + permuter_config = {"target": "taming.modules.transformer.permuter.Identity"} + self.permuter = instantiate_from_config(config=permuter_config) + self.transformer = instantiate_from_config(config=transformer_config) + + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) + self.downsample_cond_size = downsample_cond_size + self.pkeep = pkeep + + def init_from_ckpt(self, path, ignore_keys=list()): + sd = torch.load(path, map_location="cpu")["state_dict"] + for k in sd.keys(): + for ik in ignore_keys: + if k.startswith(ik): + self.print("Deleting key {} from state_dict.".format(k)) + del sd[k] + self.load_state_dict(sd, strict=False) + print(f"Restored from {path}") + + def init_first_stage_from_ckpt(self, config): + model = instantiate_from_config(config) + model = model.eval() + model.train = disabled_train + self.first_stage_model = model + + def init_cond_stage_from_ckpt(self, config): + if config == "__is_first_stage__": + print("Using first stage also as cond stage.") + self.cond_stage_model = self.first_stage_model + elif config == "__is_unconditional__" or self.be_unconditional: + print(f"Using no cond stage. Assuming the training is intended to be unconditional. " + f"Prepending {self.sos_token} as a sos token.") + self.be_unconditional = True + self.cond_stage_key = self.first_stage_key + self.cond_stage_model = SOSProvider(self.sos_token) + else: + model = instantiate_from_config(config) + model = model.eval() + model.train = disabled_train + self.cond_stage_model = model + + def forward(self, x, c): + # one step to produce the logits + _, z_indices = self.encode_to_z(x) + _, c_indices = self.encode_to_c(c) + + if self.training and self.pkeep < 1.0: + mask = torch.bernoulli(self.pkeep*torch.ones(z_indices.shape, + device=z_indices.device)) + mask = mask.round().to(dtype=torch.int64) + r_indices = torch.randint_like(z_indices, self.transformer.config.vocab_size) + a_indices = mask*z_indices+(1-mask)*r_indices + else: + a_indices = z_indices + + cz_indices = torch.cat((c_indices, a_indices), dim=1) + + # target includes all sequence elements (no need to handle first one + # differently because we are conditioning) + target = z_indices + # make the prediction + logits, _ = self.transformer(cz_indices[:, :-1]) + # cut off conditioning outputs - output i corresponds to p(z_i | z_{ -1: + c = F.interpolate(c, size=(self.downsample_cond_size, self.downsample_cond_size)) + quant_c, _, [_,_,indices] = self.cond_stage_model.encode(c) + if len(indices.shape) > 2: + indices = indices.view(c.shape[0], -1) + return quant_c, indices + + @torch.no_grad() + def decode_to_img(self, index, zshape): + index = self.permuter(index, reverse=True) + bhwc = (zshape[0],zshape[2],zshape[3],zshape[1]) + quant_z = self.first_stage_model.quantize.get_codebook_entry( + index.reshape(-1), shape=bhwc) + x = self.first_stage_model.decode(quant_z) + return x + + @torch.no_grad() + def log_images(self, batch, temperature=None, top_k=None, callback=None, lr_interface=False, **kwargs): + log = dict() + + N = 4 + if lr_interface: + x, c = self.get_xc(batch, N, diffuse=False, upsample_factor=8) + else: + x, c = self.get_xc(batch, N) + x = x.to(device=self.device) + c = c.to(device=self.device) + + quant_z, z_indices = self.encode_to_z(x) + quant_c, c_indices = self.encode_to_c(c) + + # create a "half"" sample + z_start_indices = z_indices[:,:z_indices.shape[1]//2] + index_sample = self.sample(z_start_indices, c_indices, + steps=z_indices.shape[1]-z_start_indices.shape[1], + temperature=temperature if temperature is not None else 1.0, + sample=True, + top_k=top_k if top_k is not None else 100, + callback=callback if callback is not None else lambda k: None) + x_sample = self.decode_to_img(index_sample, quant_z.shape) + + # sample + z_start_indices = z_indices[:, :0] + index_sample = self.sample(z_start_indices, c_indices, + steps=z_indices.shape[1], + temperature=temperature if temperature is not None else 1.0, + sample=True, + top_k=top_k if top_k is not None else 100, + callback=callback if callback is not None else lambda k: None) + x_sample_nopix = self.decode_to_img(index_sample, quant_z.shape) + + # det sample + z_start_indices = z_indices[:, :0] + index_sample = self.sample(z_start_indices, c_indices, + steps=z_indices.shape[1], + sample=False, + callback=callback if callback is not None else lambda k: None) + x_sample_det = self.decode_to_img(index_sample, quant_z.shape) + + # reconstruction + x_rec = self.decode_to_img(z_indices, quant_z.shape) + + log["inputs"] = x + log["reconstructions"] = x_rec + + if self.cond_stage_key in ["objects_bbox", "objects_center_points"]: + figure_size = (x_rec.shape[2], x_rec.shape[3]) + dataset = kwargs["pl_module"].trainer.datamodule.datasets["validation"] + label_for_category_no = dataset.get_textual_label_for_category_no + plotter = dataset.conditional_builders[self.cond_stage_key].plot + log["conditioning"] = torch.zeros_like(log["reconstructions"]) + for i in range(quant_c.shape[0]): + log["conditioning"][i] = plotter(quant_c[i], label_for_category_no, figure_size) + log["conditioning_rec"] = log["conditioning"] + elif self.cond_stage_key != "image": + cond_rec = self.cond_stage_model.decode(quant_c) + if self.cond_stage_key == "segmentation": + # get image from segmentation mask + num_classes = cond_rec.shape[1] + + c = torch.argmax(c, dim=1, keepdim=True) + c = F.one_hot(c, num_classes=num_classes) + c = c.squeeze(1).permute(0, 3, 1, 2).float() + c = self.cond_stage_model.to_rgb(c) + + cond_rec = torch.argmax(cond_rec, dim=1, keepdim=True) + cond_rec = F.one_hot(cond_rec, num_classes=num_classes) + cond_rec = cond_rec.squeeze(1).permute(0, 3, 1, 2).float() + cond_rec = self.cond_stage_model.to_rgb(cond_rec) + log["conditioning_rec"] = cond_rec + log["conditioning"] = c + + log["samples_half"] = x_sample + log["samples_nopix"] = x_sample_nopix + log["samples_det"] = x_sample_det + return log + + def get_input(self, key, batch): + x = batch[key] + if len(x.shape) == 3: + x = x[..., None] + if len(x.shape) == 4: + x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format) + if x.dtype == torch.double: + x = x.float() + return x + + def get_xc(self, batch, N=None): + x = self.get_input(self.first_stage_key, batch) + c = self.get_input(self.cond_stage_key, batch) + if N is not None: + x = x[:N] + c = c[:N] + return x, c + + def shared_step(self, batch, batch_idx): + x, c = self.get_xc(batch) + logits, target = self(x, c) + loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), target.reshape(-1)) + return loss + + def training_step(self, batch, batch_idx): + loss = self.shared_step(batch, batch_idx) + self.log("train/loss", loss, prog_bar=True, logger=True, on_step=True, on_epoch=True) + return loss + + def validation_step(self, batch, batch_idx): + loss = self.shared_step(batch, batch_idx) + self.log("val/loss", loss, prog_bar=True, logger=True, on_step=True, on_epoch=True) + return loss + + def configure_optimizers(self): + """ + Following minGPT: + This long function is unfortunately doing something very simple and is being very defensive: + We are separating out all parameters of the model into two buckets: those that will experience + weight decay for regularization and those that won't (biases, and layernorm/embedding weights). + We are then returning the PyTorch optimizer object. + """ + # separate out all parameters to those that will and won't experience regularizing weight decay + decay = set() + no_decay = set() + whitelist_weight_modules = (torch.nn.Linear, ) + blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding) + for mn, m in self.transformer.named_modules(): + for pn, p in m.named_parameters(): + fpn = '%s.%s' % (mn, pn) if mn else pn # full param name + + if pn.endswith('bias'): + # all biases will not be decayed + no_decay.add(fpn) + elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules): + # weights of whitelist modules will be weight decayed + decay.add(fpn) + elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules): + # weights of blacklist modules will NOT be weight decayed + no_decay.add(fpn) + + # special case the position embedding parameter in the root GPT module as not decayed + no_decay.add('pos_emb') + + # validate that we considered every parameter + param_dict = {pn: p for pn, p in self.transformer.named_parameters()} + inter_params = decay & no_decay + union_params = decay | no_decay + assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), ) + assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \ + % (str(param_dict.keys() - union_params), ) + + # create the pytorch optimizer object + optim_groups = [ + {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": 0.01}, + {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0}, + ] + optimizer = torch.optim.AdamW(optim_groups, lr=self.learning_rate, betas=(0.9, 0.95)) + return optimizer diff --git a/deforum-stable-diffusion/src/taming/models/dummy_cond_stage.py b/deforum-stable-diffusion/src/taming/models/dummy_cond_stage.py new file mode 100644 index 0000000000000000000000000000000000000000..6e19938078752e09b926a3e749907ee99a258ca0 --- /dev/null +++ b/deforum-stable-diffusion/src/taming/models/dummy_cond_stage.py @@ -0,0 +1,22 @@ +from torch import Tensor + + +class DummyCondStage: + def __init__(self, conditional_key): + self.conditional_key = conditional_key + self.train = None + + def eval(self): + return self + + @staticmethod + def encode(c: Tensor): + return c, None, (None, None, c) + + @staticmethod + def decode(c: Tensor): + return c + + @staticmethod + def to_rgb(c: Tensor): + return c diff --git a/deforum-stable-diffusion/src/taming/models/vqgan.py b/deforum-stable-diffusion/src/taming/models/vqgan.py new file mode 100644 index 0000000000000000000000000000000000000000..a6950baa5f739111cd64c17235dca8be3a5f8037 --- /dev/null +++ b/deforum-stable-diffusion/src/taming/models/vqgan.py @@ -0,0 +1,404 @@ +import torch +import torch.nn.functional as F +import pytorch_lightning as pl + +from main import instantiate_from_config + +from taming.modules.diffusionmodules.model import Encoder, Decoder +from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer +from taming.modules.vqvae.quantize import GumbelQuantize +from taming.modules.vqvae.quantize import EMAVectorQuantizer + +class VQModel(pl.LightningModule): + def __init__(self, + ddconfig, + lossconfig, + n_embed, + embed_dim, + ckpt_path=None, + ignore_keys=[], + image_key="image", + colorize_nlabels=None, + monitor=None, + remap=None, + sane_index_shape=False, # tell vector quantizer to return indices as bhw + ): + super().__init__() + self.image_key = image_key + self.encoder = Encoder(**ddconfig) + self.decoder = Decoder(**ddconfig) + self.loss = instantiate_from_config(lossconfig) + self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25, + remap=remap, sane_index_shape=sane_index_shape) + self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1) + self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) + self.image_key = image_key + if colorize_nlabels is not None: + assert type(colorize_nlabels)==int + self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) + if monitor is not None: + self.monitor = monitor + + def init_from_ckpt(self, path, ignore_keys=list()): + sd = torch.load(path, map_location="cpu")["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + self.load_state_dict(sd, strict=False) + print(f"Restored from {path}") + + def encode(self, x): + h = self.encoder(x) + h = self.quant_conv(h) + quant, emb_loss, info = self.quantize(h) + return quant, emb_loss, info + + def decode(self, quant): + quant = self.post_quant_conv(quant) + dec = self.decoder(quant) + return dec + + def decode_code(self, code_b): + quant_b = self.quantize.embed_code(code_b) + dec = self.decode(quant_b) + return dec + + def forward(self, input): + quant, diff, _ = self.encode(input) + dec = self.decode(quant) + return dec, diff + + def get_input(self, batch, k): + x = batch[k] + if len(x.shape) == 3: + x = x[..., None] + x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format) + return x.float() + + def training_step(self, batch, batch_idx, optimizer_idx): + x = self.get_input(batch, self.image_key) + xrec, qloss = self(x) + + if optimizer_idx == 0: + # autoencode + aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step, + last_layer=self.get_last_layer(), split="train") + + self.log("train/aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) + self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True) + return aeloss + + if optimizer_idx == 1: + # discriminator + discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step, + last_layer=self.get_last_layer(), split="train") + self.log("train/discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) + self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True) + return discloss + + def validation_step(self, batch, batch_idx): + x = self.get_input(batch, self.image_key) + xrec, qloss = self(x) + aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0, self.global_step, + last_layer=self.get_last_layer(), split="val") + + discloss, log_dict_disc = self.loss(qloss, x, xrec, 1, self.global_step, + last_layer=self.get_last_layer(), split="val") + rec_loss = log_dict_ae["val/rec_loss"] + self.log("val/rec_loss", rec_loss, + prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True) + self.log("val/aeloss", aeloss, + prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True) + self.log_dict(log_dict_ae) + self.log_dict(log_dict_disc) + return self.log_dict + + def configure_optimizers(self): + lr = self.learning_rate + opt_ae = torch.optim.Adam(list(self.encoder.parameters())+ + list(self.decoder.parameters())+ + list(self.quantize.parameters())+ + list(self.quant_conv.parameters())+ + list(self.post_quant_conv.parameters()), + lr=lr, betas=(0.5, 0.9)) + opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), + lr=lr, betas=(0.5, 0.9)) + return [opt_ae, opt_disc], [] + + def get_last_layer(self): + return self.decoder.conv_out.weight + + def log_images(self, batch, **kwargs): + log = dict() + x = self.get_input(batch, self.image_key) + x = x.to(self.device) + xrec, _ = self(x) + if x.shape[1] > 3: + # colorize with random projection + assert xrec.shape[1] > 3 + x = self.to_rgb(x) + xrec = self.to_rgb(xrec) + log["inputs"] = x + log["reconstructions"] = xrec + return log + + def to_rgb(self, x): + assert self.image_key == "segmentation" + if not hasattr(self, "colorize"): + self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) + x = F.conv2d(x, weight=self.colorize) + x = 2.*(x-x.min())/(x.max()-x.min()) - 1. + return x + + +class VQSegmentationModel(VQModel): + def __init__(self, n_labels, *args, **kwargs): + super().__init__(*args, **kwargs) + self.register_buffer("colorize", torch.randn(3, n_labels, 1, 1)) + + def configure_optimizers(self): + lr = self.learning_rate + opt_ae = torch.optim.Adam(list(self.encoder.parameters())+ + list(self.decoder.parameters())+ + list(self.quantize.parameters())+ + list(self.quant_conv.parameters())+ + list(self.post_quant_conv.parameters()), + lr=lr, betas=(0.5, 0.9)) + return opt_ae + + def training_step(self, batch, batch_idx): + x = self.get_input(batch, self.image_key) + xrec, qloss = self(x) + aeloss, log_dict_ae = self.loss(qloss, x, xrec, split="train") + self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True) + return aeloss + + def validation_step(self, batch, batch_idx): + x = self.get_input(batch, self.image_key) + xrec, qloss = self(x) + aeloss, log_dict_ae = self.loss(qloss, x, xrec, split="val") + self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True) + total_loss = log_dict_ae["val/total_loss"] + self.log("val/total_loss", total_loss, + prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True) + return aeloss + + @torch.no_grad() + def log_images(self, batch, **kwargs): + log = dict() + x = self.get_input(batch, self.image_key) + x = x.to(self.device) + xrec, _ = self(x) + if x.shape[1] > 3: + # colorize with random projection + assert xrec.shape[1] > 3 + # convert logits to indices + xrec = torch.argmax(xrec, dim=1, keepdim=True) + xrec = F.one_hot(xrec, num_classes=x.shape[1]) + xrec = xrec.squeeze(1).permute(0, 3, 1, 2).float() + x = self.to_rgb(x) + xrec = self.to_rgb(xrec) + log["inputs"] = x + log["reconstructions"] = xrec + return log + + +class VQNoDiscModel(VQModel): + def __init__(self, + ddconfig, + lossconfig, + n_embed, + embed_dim, + ckpt_path=None, + ignore_keys=[], + image_key="image", + colorize_nlabels=None + ): + super().__init__(ddconfig=ddconfig, lossconfig=lossconfig, n_embed=n_embed, embed_dim=embed_dim, + ckpt_path=ckpt_path, ignore_keys=ignore_keys, image_key=image_key, + colorize_nlabels=colorize_nlabels) + + def training_step(self, batch, batch_idx): + x = self.get_input(batch, self.image_key) + xrec, qloss = self(x) + # autoencode + aeloss, log_dict_ae = self.loss(qloss, x, xrec, self.global_step, split="train") + output = pl.TrainResult(minimize=aeloss) + output.log("train/aeloss", aeloss, + prog_bar=True, logger=True, on_step=True, on_epoch=True) + output.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True) + return output + + def validation_step(self, batch, batch_idx): + x = self.get_input(batch, self.image_key) + xrec, qloss = self(x) + aeloss, log_dict_ae = self.loss(qloss, x, xrec, self.global_step, split="val") + rec_loss = log_dict_ae["val/rec_loss"] + output = pl.EvalResult(checkpoint_on=rec_loss) + output.log("val/rec_loss", rec_loss, + prog_bar=True, logger=True, on_step=True, on_epoch=True) + output.log("val/aeloss", aeloss, + prog_bar=True, logger=True, on_step=True, on_epoch=True) + output.log_dict(log_dict_ae) + + return output + + def configure_optimizers(self): + optimizer = torch.optim.Adam(list(self.encoder.parameters())+ + list(self.decoder.parameters())+ + list(self.quantize.parameters())+ + list(self.quant_conv.parameters())+ + list(self.post_quant_conv.parameters()), + lr=self.learning_rate, betas=(0.5, 0.9)) + return optimizer + + +class GumbelVQ(VQModel): + def __init__(self, + ddconfig, + lossconfig, + n_embed, + embed_dim, + temperature_scheduler_config, + ckpt_path=None, + ignore_keys=[], + image_key="image", + colorize_nlabels=None, + monitor=None, + kl_weight=1e-8, + remap=None, + ): + + z_channels = ddconfig["z_channels"] + super().__init__(ddconfig, + lossconfig, + n_embed, + embed_dim, + ckpt_path=None, + ignore_keys=ignore_keys, + image_key=image_key, + colorize_nlabels=colorize_nlabels, + monitor=monitor, + ) + + self.loss.n_classes = n_embed + self.vocab_size = n_embed + + self.quantize = GumbelQuantize(z_channels, embed_dim, + n_embed=n_embed, + kl_weight=kl_weight, temp_init=1.0, + remap=remap) + + self.temperature_scheduler = instantiate_from_config(temperature_scheduler_config) # annealing of temp + + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) + + def temperature_scheduling(self): + self.quantize.temperature = self.temperature_scheduler(self.global_step) + + def encode_to_prequant(self, x): + h = self.encoder(x) + h = self.quant_conv(h) + return h + + def decode_code(self, code_b): + raise NotImplementedError + + def training_step(self, batch, batch_idx, optimizer_idx): + self.temperature_scheduling() + x = self.get_input(batch, self.image_key) + xrec, qloss = self(x) + + if optimizer_idx == 0: + # autoencode + aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step, + last_layer=self.get_last_layer(), split="train") + + self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True) + self.log("temperature", self.quantize.temperature, prog_bar=False, logger=True, on_step=True, on_epoch=True) + return aeloss + + if optimizer_idx == 1: + # discriminator + discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step, + last_layer=self.get_last_layer(), split="train") + self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True) + return discloss + + def validation_step(self, batch, batch_idx): + x = self.get_input(batch, self.image_key) + xrec, qloss = self(x, return_pred_indices=True) + aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0, self.global_step, + last_layer=self.get_last_layer(), split="val") + + discloss, log_dict_disc = self.loss(qloss, x, xrec, 1, self.global_step, + last_layer=self.get_last_layer(), split="val") + rec_loss = log_dict_ae["val/rec_loss"] + self.log("val/rec_loss", rec_loss, + prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True) + self.log("val/aeloss", aeloss, + prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True) + self.log_dict(log_dict_ae) + self.log_dict(log_dict_disc) + return self.log_dict + + def log_images(self, batch, **kwargs): + log = dict() + x = self.get_input(batch, self.image_key) + x = x.to(self.device) + # encode + h = self.encoder(x) + h = self.quant_conv(h) + quant, _, _ = self.quantize(h) + # decode + x_rec = self.decode(quant) + log["inputs"] = x + log["reconstructions"] = x_rec + return log + + +class EMAVQ(VQModel): + def __init__(self, + ddconfig, + lossconfig, + n_embed, + embed_dim, + ckpt_path=None, + ignore_keys=[], + image_key="image", + colorize_nlabels=None, + monitor=None, + remap=None, + sane_index_shape=False, # tell vector quantizer to return indices as bhw + ): + super().__init__(ddconfig, + lossconfig, + n_embed, + embed_dim, + ckpt_path=None, + ignore_keys=ignore_keys, + image_key=image_key, + colorize_nlabels=colorize_nlabels, + monitor=monitor, + ) + self.quantize = EMAVectorQuantizer(n_embed=n_embed, + embedding_dim=embed_dim, + beta=0.25, + remap=remap) + def configure_optimizers(self): + lr = self.learning_rate + #Remove self.quantize from parameter list since it is updated via EMA + opt_ae = torch.optim.Adam(list(self.encoder.parameters())+ + list(self.decoder.parameters())+ + list(self.quant_conv.parameters())+ + list(self.post_quant_conv.parameters()), + lr=lr, betas=(0.5, 0.9)) + opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), + lr=lr, betas=(0.5, 0.9)) + return [opt_ae, opt_disc], [] \ No newline at end of file diff --git a/deforum-stable-diffusion/src/taming/modules/diffusionmodules/model.py b/deforum-stable-diffusion/src/taming/modules/diffusionmodules/model.py new file mode 100644 index 0000000000000000000000000000000000000000..d3a5db6aa2ef915e270f1ae135e4a9918fdd884c --- /dev/null +++ b/deforum-stable-diffusion/src/taming/modules/diffusionmodules/model.py @@ -0,0 +1,776 @@ +# pytorch_diffusion + derived encoder decoder +import math +import torch +import torch.nn as nn +import numpy as np + + +def get_timestep_embedding(timesteps, embedding_dim): + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: + From Fairseq. + Build sinusoidal embeddings. + This matches the implementation in tensor2tensor, but differs slightly + from the description in Section 3.5 of "Attention Is All You Need". + """ + assert len(timesteps.shape) == 1 + + half_dim = embedding_dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) + emb = emb.to(device=timesteps.device) + emb = timesteps.float()[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0,1,0,0)) + return emb + + +def nonlinearity(x): + # swish + return x*torch.sigmoid(x) + + +def Normalize(in_channels): + return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + + +class Upsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + self.conv = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + if self.with_conv: + x = self.conv(x) + return x + + +class Downsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=3, + stride=2, + padding=0) + + def forward(self, x): + if self.with_conv: + pad = (0,1,0,1) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + else: + x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) + return x + + +class ResnetBlock(nn.Module): + def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, + dropout, temb_channels=512): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize(in_channels) + self.conv1 = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + if temb_channels > 0: + self.temb_proj = torch.nn.Linear(temb_channels, + out_channels) + self.norm2 = Normalize(out_channels) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = torch.nn.Conv2d(out_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + else: + self.nin_shortcut = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, x, temb): + h = x + h = self.norm1(h) + h = nonlinearity(h) + h = self.conv1(h) + + if temb is not None: + h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None] + + h = self.norm2(h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + + return x+h + + +class AttnBlock(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.k = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.v = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b,c,h,w = q.shape + q = q.reshape(b,c,h*w) + q = q.permute(0,2,1) # b,hw,c + k = k.reshape(b,c,h*w) # b,c,hw + w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w_ = w_ * (int(c)**(-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = v.reshape(b,c,h*w) + w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q) + h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + h_ = h_.reshape(b,c,h,w) + + h_ = self.proj_out(h_) + + return x+h_ + + +class Model(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, + resolution, use_timestep=True): + super().__init__() + self.ch = ch + self.temb_ch = self.ch*4 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + self.use_timestep = use_timestep + if self.use_timestep: + # timestep embedding + self.temb = nn.Module() + self.temb.dense = nn.ModuleList([ + torch.nn.Linear(self.ch, + self.temb_ch), + torch.nn.Linear(self.temb_ch, + self.temb_ch), + ]) + + # downsampling + self.conv_in = torch.nn.Conv2d(in_channels, + self.ch, + kernel_size=3, + stride=1, + padding=1) + + curr_res = resolution + in_ch_mult = (1,)+tuple(ch_mult) + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch*in_ch_mult[i_level] + block_out = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(AttnBlock(block_in)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions-1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch*ch_mult[i_level] + skip_in = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks+1): + if i_block == self.num_res_blocks: + skip_in = ch*in_ch_mult[i_level] + block.append(ResnetBlock(in_channels=block_in+skip_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(AttnBlock(block_in)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + out_ch, + kernel_size=3, + stride=1, + padding=1) + + + def forward(self, x, t=None): + #assert x.shape[2] == x.shape[3] == self.resolution + + if self.use_timestep: + # timestep embedding + assert t is not None + temb = get_timestep_embedding(t, self.ch) + temb = self.temb.dense[0](temb) + temb = nonlinearity(temb) + temb = self.temb.dense[1](temb) + else: + temb = None + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions-1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks+1): + h = self.up[i_level].block[i_block]( + torch.cat([h, hs.pop()], dim=1), temb) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class Encoder(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, + resolution, z_channels, double_z=True, **ignore_kwargs): + super().__init__() + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + # downsampling + self.conv_in = torch.nn.Conv2d(in_channels, + self.ch, + kernel_size=3, + stride=1, + padding=1) + + curr_res = resolution + in_ch_mult = (1,)+tuple(ch_mult) + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch*in_ch_mult[i_level] + block_out = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(AttnBlock(block_in)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions-1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + 2*z_channels if double_z else z_channels, + kernel_size=3, + stride=1, + padding=1) + + + def forward(self, x): + #assert x.shape[2] == x.shape[3] == self.resolution, "{}, {}, {}".format(x.shape[2], x.shape[3], self.resolution) + + # timestep embedding + temb = None + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions-1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class Decoder(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, + resolution, z_channels, give_pre_end=False, **ignorekwargs): + super().__init__() + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.give_pre_end = give_pre_end + + # compute in_ch_mult, block_in and curr_res at lowest res + in_ch_mult = (1,)+tuple(ch_mult) + block_in = ch*ch_mult[self.num_resolutions-1] + curr_res = resolution // 2**(self.num_resolutions-1) + self.z_shape = (1,z_channels,curr_res,curr_res) + print("Working with z of shape {} = {} dimensions.".format( + self.z_shape, np.prod(self.z_shape))) + + # z to block_in + self.conv_in = torch.nn.Conv2d(z_channels, + block_in, + kernel_size=3, + stride=1, + padding=1) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks+1): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(AttnBlock(block_in)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + out_ch, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, z): + #assert z.shape[1:] == self.z_shape[1:] + self.last_z_shape = z.shape + + # timestep embedding + temb = None + + # z to block_in + h = self.conv_in(z) + + # middle + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks+1): + h = self.up[i_level].block[i_block](h, temb) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + if self.give_pre_end: + return h + + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class VUNet(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, + in_channels, c_channels, + resolution, z_channels, use_timestep=False, **ignore_kwargs): + super().__init__() + self.ch = ch + self.temb_ch = self.ch*4 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + + self.use_timestep = use_timestep + if self.use_timestep: + # timestep embedding + self.temb = nn.Module() + self.temb.dense = nn.ModuleList([ + torch.nn.Linear(self.ch, + self.temb_ch), + torch.nn.Linear(self.temb_ch, + self.temb_ch), + ]) + + # downsampling + self.conv_in = torch.nn.Conv2d(c_channels, + self.ch, + kernel_size=3, + stride=1, + padding=1) + + curr_res = resolution + in_ch_mult = (1,)+tuple(ch_mult) + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch*in_ch_mult[i_level] + block_out = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(AttnBlock(block_in)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions-1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + self.z_in = torch.nn.Conv2d(z_channels, + block_in, + kernel_size=1, + stride=1, + padding=0) + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=2*block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch*ch_mult[i_level] + skip_in = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks+1): + if i_block == self.num_res_blocks: + skip_in = ch*in_ch_mult[i_level] + block.append(ResnetBlock(in_channels=block_in+skip_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(AttnBlock(block_in)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + out_ch, + kernel_size=3, + stride=1, + padding=1) + + + def forward(self, x, z): + #assert x.shape[2] == x.shape[3] == self.resolution + + if self.use_timestep: + # timestep embedding + assert t is not None + temb = get_timestep_embedding(t, self.ch) + temb = self.temb.dense[0](temb) + temb = nonlinearity(temb) + temb = self.temb.dense[1](temb) + else: + temb = None + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions-1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + z = self.z_in(z) + h = torch.cat((h,z),dim=1) + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks+1): + h = self.up[i_level].block[i_block]( + torch.cat([h, hs.pop()], dim=1), temb) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class SimpleDecoder(nn.Module): + def __init__(self, in_channels, out_channels, *args, **kwargs): + super().__init__() + self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1), + ResnetBlock(in_channels=in_channels, + out_channels=2 * in_channels, + temb_channels=0, dropout=0.0), + ResnetBlock(in_channels=2 * in_channels, + out_channels=4 * in_channels, + temb_channels=0, dropout=0.0), + ResnetBlock(in_channels=4 * in_channels, + out_channels=2 * in_channels, + temb_channels=0, dropout=0.0), + nn.Conv2d(2*in_channels, in_channels, 1), + Upsample(in_channels, with_conv=True)]) + # end + self.norm_out = Normalize(in_channels) + self.conv_out = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + for i, layer in enumerate(self.model): + if i in [1,2,3]: + x = layer(x, None) + else: + x = layer(x) + + h = self.norm_out(x) + h = nonlinearity(h) + x = self.conv_out(h) + return x + + +class UpsampleDecoder(nn.Module): + def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution, + ch_mult=(2,2), dropout=0.0): + super().__init__() + # upsampling + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + block_in = in_channels + curr_res = resolution // 2 ** (self.num_resolutions - 1) + self.res_blocks = nn.ModuleList() + self.upsample_blocks = nn.ModuleList() + for i_level in range(self.num_resolutions): + res_block = [] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + res_block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + self.res_blocks.append(nn.ModuleList(res_block)) + if i_level != self.num_resolutions - 1: + self.upsample_blocks.append(Upsample(block_in, True)) + curr_res = curr_res * 2 + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + out_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + # upsampling + h = x + for k, i_level in enumerate(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.res_blocks[i_level][i_block](h, None) + if i_level != self.num_resolutions - 1: + h = self.upsample_blocks[k](h) + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + diff --git a/deforum-stable-diffusion/src/taming/modules/discriminator/model.py b/deforum-stable-diffusion/src/taming/modules/discriminator/model.py new file mode 100644 index 0000000000000000000000000000000000000000..2aaa3110d0a7bcd05de7eca1e45101589ca5af05 --- /dev/null +++ b/deforum-stable-diffusion/src/taming/modules/discriminator/model.py @@ -0,0 +1,67 @@ +import functools +import torch.nn as nn + + +from taming.modules.util import ActNorm + + +def weights_init(m): + classname = m.__class__.__name__ + if classname.find('Conv') != -1: + nn.init.normal_(m.weight.data, 0.0, 0.02) + elif classname.find('BatchNorm') != -1: + nn.init.normal_(m.weight.data, 1.0, 0.02) + nn.init.constant_(m.bias.data, 0) + + +class NLayerDiscriminator(nn.Module): + """Defines a PatchGAN discriminator as in Pix2Pix + --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py + """ + def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False): + """Construct a PatchGAN discriminator + Parameters: + input_nc (int) -- the number of channels in input images + ndf (int) -- the number of filters in the last conv layer + n_layers (int) -- the number of conv layers in the discriminator + norm_layer -- normalization layer + """ + super(NLayerDiscriminator, self).__init__() + if not use_actnorm: + norm_layer = nn.BatchNorm2d + else: + norm_layer = ActNorm + if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters + use_bias = norm_layer.func != nn.BatchNorm2d + else: + use_bias = norm_layer != nn.BatchNorm2d + + kw = 4 + padw = 1 + sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)] + nf_mult = 1 + nf_mult_prev = 1 + for n in range(1, n_layers): # gradually increase the number of filters + nf_mult_prev = nf_mult + nf_mult = min(2 ** n, 8) + sequence += [ + nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias), + norm_layer(ndf * nf_mult), + nn.LeakyReLU(0.2, True) + ] + + nf_mult_prev = nf_mult + nf_mult = min(2 ** n_layers, 8) + sequence += [ + nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias), + norm_layer(ndf * nf_mult), + nn.LeakyReLU(0.2, True) + ] + + sequence += [ + nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map + self.main = nn.Sequential(*sequence) + + def forward(self, input): + """Standard forward.""" + return self.main(input) diff --git a/deforum-stable-diffusion/src/taming/modules/losses/__init__.py b/deforum-stable-diffusion/src/taming/modules/losses/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d09caf9eb805f849a517f1b23503e1a4d6ea1ec5 --- /dev/null +++ b/deforum-stable-diffusion/src/taming/modules/losses/__init__.py @@ -0,0 +1,2 @@ +from taming.modules.losses.vqperceptual import DummyLoss + diff --git a/deforum-stable-diffusion/src/taming/modules/losses/lpips.py b/deforum-stable-diffusion/src/taming/modules/losses/lpips.py new file mode 100644 index 0000000000000000000000000000000000000000..a7280447694ffc302a7636e7e4d6183408e0aa95 --- /dev/null +++ b/deforum-stable-diffusion/src/taming/modules/losses/lpips.py @@ -0,0 +1,123 @@ +"""Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models""" + +import torch +import torch.nn as nn +from torchvision import models +from collections import namedtuple + +from taming.util import get_ckpt_path + + +class LPIPS(nn.Module): + # Learned perceptual metric + def __init__(self, use_dropout=True): + super().__init__() + self.scaling_layer = ScalingLayer() + self.chns = [64, 128, 256, 512, 512] # vg16 features + self.net = vgg16(pretrained=True, requires_grad=False) + self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) + self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) + self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) + self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) + self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) + self.load_from_pretrained() + for param in self.parameters(): + param.requires_grad = False + + def load_from_pretrained(self, name="vgg_lpips"): + ckpt = get_ckpt_path(name, "taming/modules/autoencoder/lpips") + self.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False) + print("loaded pretrained LPIPS loss from {}".format(ckpt)) + + @classmethod + def from_pretrained(cls, name="vgg_lpips"): + if name != "vgg_lpips": + raise NotImplementedError + model = cls() + ckpt = get_ckpt_path(name) + model.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False) + return model + + def forward(self, input, target): + in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target)) + outs0, outs1 = self.net(in0_input), self.net(in1_input) + feats0, feats1, diffs = {}, {}, {} + lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4] + for kk in range(len(self.chns)): + feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk]) + diffs[kk] = (feats0[kk] - feats1[kk]) ** 2 + + res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) for kk in range(len(self.chns))] + val = res[0] + for l in range(1, len(self.chns)): + val += res[l] + return val + + +class ScalingLayer(nn.Module): + def __init__(self): + super(ScalingLayer, self).__init__() + self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None]) + self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None]) + + def forward(self, inp): + return (inp - self.shift) / self.scale + + +class NetLinLayer(nn.Module): + """ A single linear layer which does a 1x1 conv """ + def __init__(self, chn_in, chn_out=1, use_dropout=False): + super(NetLinLayer, self).__init__() + layers = [nn.Dropout(), ] if (use_dropout) else [] + layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ] + self.model = nn.Sequential(*layers) + + +class vgg16(torch.nn.Module): + def __init__(self, requires_grad=False, pretrained=True): + super(vgg16, self).__init__() + vgg_pretrained_features = models.vgg16(pretrained=pretrained).features + self.slice1 = torch.nn.Sequential() + self.slice2 = torch.nn.Sequential() + self.slice3 = torch.nn.Sequential() + self.slice4 = torch.nn.Sequential() + self.slice5 = torch.nn.Sequential() + self.N_slices = 5 + for x in range(4): + self.slice1.add_module(str(x), vgg_pretrained_features[x]) + for x in range(4, 9): + self.slice2.add_module(str(x), vgg_pretrained_features[x]) + for x in range(9, 16): + self.slice3.add_module(str(x), vgg_pretrained_features[x]) + for x in range(16, 23): + self.slice4.add_module(str(x), vgg_pretrained_features[x]) + for x in range(23, 30): + self.slice5.add_module(str(x), vgg_pretrained_features[x]) + if not requires_grad: + for param in self.parameters(): + param.requires_grad = False + + def forward(self, X): + h = self.slice1(X) + h_relu1_2 = h + h = self.slice2(h) + h_relu2_2 = h + h = self.slice3(h) + h_relu3_3 = h + h = self.slice4(h) + h_relu4_3 = h + h = self.slice5(h) + h_relu5_3 = h + vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3']) + out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) + return out + + +def normalize_tensor(x,eps=1e-10): + norm_factor = torch.sqrt(torch.sum(x**2,dim=1,keepdim=True)) + return x/(norm_factor+eps) + + +def spatial_average(x, keepdim=True): + return x.mean([2,3],keepdim=keepdim) + diff --git a/deforum-stable-diffusion/src/taming/modules/losses/segmentation.py b/deforum-stable-diffusion/src/taming/modules/losses/segmentation.py new file mode 100644 index 0000000000000000000000000000000000000000..4ba77deb5159a6307ed2acba9945e4764a4ff0a5 --- /dev/null +++ b/deforum-stable-diffusion/src/taming/modules/losses/segmentation.py @@ -0,0 +1,22 @@ +import torch.nn as nn +import torch.nn.functional as F + + +class BCELoss(nn.Module): + def forward(self, prediction, target): + loss = F.binary_cross_entropy_with_logits(prediction,target) + return loss, {} + + +class BCELossWithQuant(nn.Module): + def __init__(self, codebook_weight=1.): + super().__init__() + self.codebook_weight = codebook_weight + + def forward(self, qloss, target, prediction, split): + bce_loss = F.binary_cross_entropy_with_logits(prediction,target) + loss = bce_loss + self.codebook_weight*qloss + return loss, {"{}/total_loss".format(split): loss.clone().detach().mean(), + "{}/bce_loss".format(split): bce_loss.detach().mean(), + "{}/quant_loss".format(split): qloss.detach().mean() + } diff --git a/deforum-stable-diffusion/src/taming/modules/losses/vqperceptual.py b/deforum-stable-diffusion/src/taming/modules/losses/vqperceptual.py new file mode 100644 index 0000000000000000000000000000000000000000..c2febd445728479d4cd9aacdb2572cb1f1af04db --- /dev/null +++ b/deforum-stable-diffusion/src/taming/modules/losses/vqperceptual.py @@ -0,0 +1,136 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from taming.modules.losses.lpips import LPIPS +from taming.modules.discriminator.model import NLayerDiscriminator, weights_init + + +class DummyLoss(nn.Module): + def __init__(self): + super().__init__() + + +def adopt_weight(weight, global_step, threshold=0, value=0.): + if global_step < threshold: + weight = value + return weight + + +def hinge_d_loss(logits_real, logits_fake): + loss_real = torch.mean(F.relu(1. - logits_real)) + loss_fake = torch.mean(F.relu(1. + logits_fake)) + d_loss = 0.5 * (loss_real + loss_fake) + return d_loss + + +def vanilla_d_loss(logits_real, logits_fake): + d_loss = 0.5 * ( + torch.mean(torch.nn.functional.softplus(-logits_real)) + + torch.mean(torch.nn.functional.softplus(logits_fake))) + return d_loss + + +class VQLPIPSWithDiscriminator(nn.Module): + def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0, + disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, + perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, + disc_ndf=64, disc_loss="hinge"): + super().__init__() + assert disc_loss in ["hinge", "vanilla"] + self.codebook_weight = codebook_weight + self.pixel_weight = pixelloss_weight + self.perceptual_loss = LPIPS().eval() + self.perceptual_weight = perceptual_weight + + self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, + n_layers=disc_num_layers, + use_actnorm=use_actnorm, + ndf=disc_ndf + ).apply(weights_init) + self.discriminator_iter_start = disc_start + if disc_loss == "hinge": + self.disc_loss = hinge_d_loss + elif disc_loss == "vanilla": + self.disc_loss = vanilla_d_loss + else: + raise ValueError(f"Unknown GAN loss '{disc_loss}'.") + print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.") + self.disc_factor = disc_factor + self.discriminator_weight = disc_weight + self.disc_conditional = disc_conditional + + def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): + if last_layer is not None: + nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] + g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] + else: + nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] + g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] + + d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) + d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() + d_weight = d_weight * self.discriminator_weight + return d_weight + + def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx, + global_step, last_layer=None, cond=None, split="train"): + rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) + if self.perceptual_weight > 0: + p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) + rec_loss = rec_loss + self.perceptual_weight * p_loss + else: + p_loss = torch.tensor([0.0]) + + nll_loss = rec_loss + #nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] + nll_loss = torch.mean(nll_loss) + + # now the GAN part + if optimizer_idx == 0: + # generator update + if cond is None: + assert not self.disc_conditional + logits_fake = self.discriminator(reconstructions.contiguous()) + else: + assert self.disc_conditional + logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1)) + g_loss = -torch.mean(logits_fake) + + try: + d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) + except RuntimeError: + assert not self.training + d_weight = torch.tensor(0.0) + + disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) + loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean() + + log = {"{}/total_loss".format(split): loss.clone().detach().mean(), + "{}/quant_loss".format(split): codebook_loss.detach().mean(), + "{}/nll_loss".format(split): nll_loss.detach().mean(), + "{}/rec_loss".format(split): rec_loss.detach().mean(), + "{}/p_loss".format(split): p_loss.detach().mean(), + "{}/d_weight".format(split): d_weight.detach(), + "{}/disc_factor".format(split): torch.tensor(disc_factor), + "{}/g_loss".format(split): g_loss.detach().mean(), + } + return loss, log + + if optimizer_idx == 1: + # second pass for discriminator update + if cond is None: + logits_real = self.discriminator(inputs.contiguous().detach()) + logits_fake = self.discriminator(reconstructions.contiguous().detach()) + else: + logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1)) + logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1)) + + disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) + d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) + + log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), + "{}/logits_real".format(split): logits_real.detach().mean(), + "{}/logits_fake".format(split): logits_fake.detach().mean() + } + return d_loss, log diff --git a/deforum-stable-diffusion/src/taming/modules/misc/coord.py b/deforum-stable-diffusion/src/taming/modules/misc/coord.py new file mode 100644 index 0000000000000000000000000000000000000000..ee69b0c897b6b382ae673622e420f55e494f5b09 --- /dev/null +++ b/deforum-stable-diffusion/src/taming/modules/misc/coord.py @@ -0,0 +1,31 @@ +import torch + +class CoordStage(object): + def __init__(self, n_embed, down_factor): + self.n_embed = n_embed + self.down_factor = down_factor + + def eval(self): + return self + + def encode(self, c): + """fake vqmodel interface""" + assert 0.0 <= c.min() and c.max() <= 1.0 + b,ch,h,w = c.shape + assert ch == 1 + + c = torch.nn.functional.interpolate(c, scale_factor=1/self.down_factor, + mode="area") + c = c.clamp(0.0, 1.0) + c = self.n_embed*c + c_quant = c.round() + c_ind = c_quant.to(dtype=torch.long) + + info = None, None, c_ind + return c_quant, None, info + + def decode(self, c): + c = c/self.n_embed + c = torch.nn.functional.interpolate(c, scale_factor=self.down_factor, + mode="nearest") + return c diff --git a/deforum-stable-diffusion/src/taming/modules/transformer/mingpt.py b/deforum-stable-diffusion/src/taming/modules/transformer/mingpt.py new file mode 100644 index 0000000000000000000000000000000000000000..d14b7b68117f4b9f297b2929397cd4f55089334c --- /dev/null +++ b/deforum-stable-diffusion/src/taming/modules/transformer/mingpt.py @@ -0,0 +1,415 @@ +""" +taken from: https://github.com/karpathy/minGPT/ +GPT model: +- the initial stem consists of a combination of token encoding and a positional encoding +- the meat of it is a uniform sequence of Transformer blocks + - each Transformer is a sequential combination of a 1-hidden-layer MLP block and a self-attention block + - all blocks feed into a central residual pathway similar to resnets +- the final decoder is a linear projection into a vanilla Softmax classifier +""" + +import math +import logging + +import torch +import torch.nn as nn +from torch.nn import functional as F +from transformers import top_k_top_p_filtering + +logger = logging.getLogger(__name__) + + +class GPTConfig: + """ base GPT config, params common to all GPT versions """ + embd_pdrop = 0.1 + resid_pdrop = 0.1 + attn_pdrop = 0.1 + + def __init__(self, vocab_size, block_size, **kwargs): + self.vocab_size = vocab_size + self.block_size = block_size + for k,v in kwargs.items(): + setattr(self, k, v) + + +class GPT1Config(GPTConfig): + """ GPT-1 like network roughly 125M params """ + n_layer = 12 + n_head = 12 + n_embd = 768 + + +class CausalSelfAttention(nn.Module): + """ + A vanilla multi-head masked self-attention layer with a projection at the end. + It is possible to use torch.nn.MultiheadAttention here but I am including an + explicit implementation here to show that there is nothing too scary here. + """ + + def __init__(self, config): + super().__init__() + assert config.n_embd % config.n_head == 0 + # key, query, value projections for all heads + self.key = nn.Linear(config.n_embd, config.n_embd) + self.query = nn.Linear(config.n_embd, config.n_embd) + self.value = nn.Linear(config.n_embd, config.n_embd) + # regularization + self.attn_drop = nn.Dropout(config.attn_pdrop) + self.resid_drop = nn.Dropout(config.resid_pdrop) + # output projection + self.proj = nn.Linear(config.n_embd, config.n_embd) + # causal mask to ensure that attention is only applied to the left in the input sequence + mask = torch.tril(torch.ones(config.block_size, + config.block_size)) + if hasattr(config, "n_unmasked"): + mask[:config.n_unmasked, :config.n_unmasked] = 1 + self.register_buffer("mask", mask.view(1, 1, config.block_size, config.block_size)) + self.n_head = config.n_head + + def forward(self, x, layer_past=None): + B, T, C = x.size() + + # calculate query, key, values for all heads in batch and move head forward to be the batch dim + k = self.key(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + q = self.query(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + v = self.value(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + + present = torch.stack((k, v)) + if layer_past is not None: + past_key, past_value = layer_past + k = torch.cat((past_key, k), dim=-2) + v = torch.cat((past_value, v), dim=-2) + + # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) + att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) + if layer_past is None: + att = att.masked_fill(self.mask[:,:,:T,:T] == 0, float('-inf')) + + att = F.softmax(att, dim=-1) + att = self.attn_drop(att) + y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) + y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side + + # output projection + y = self.resid_drop(self.proj(y)) + return y, present # TODO: check that this does not break anything + + +class Block(nn.Module): + """ an unassuming Transformer block """ + def __init__(self, config): + super().__init__() + self.ln1 = nn.LayerNorm(config.n_embd) + self.ln2 = nn.LayerNorm(config.n_embd) + self.attn = CausalSelfAttention(config) + self.mlp = nn.Sequential( + nn.Linear(config.n_embd, 4 * config.n_embd), + nn.GELU(), # nice + nn.Linear(4 * config.n_embd, config.n_embd), + nn.Dropout(config.resid_pdrop), + ) + + def forward(self, x, layer_past=None, return_present=False): + # TODO: check that training still works + if return_present: assert not self.training + # layer past: tuple of length two with B, nh, T, hs + attn, present = self.attn(self.ln1(x), layer_past=layer_past) + + x = x + attn + x = x + self.mlp(self.ln2(x)) + if layer_past is not None or return_present: + return x, present + return x + + +class GPT(nn.Module): + """ the full GPT language model, with a context size of block_size """ + def __init__(self, vocab_size, block_size, n_layer=12, n_head=8, n_embd=256, + embd_pdrop=0., resid_pdrop=0., attn_pdrop=0., n_unmasked=0): + super().__init__() + config = GPTConfig(vocab_size=vocab_size, block_size=block_size, + embd_pdrop=embd_pdrop, resid_pdrop=resid_pdrop, attn_pdrop=attn_pdrop, + n_layer=n_layer, n_head=n_head, n_embd=n_embd, + n_unmasked=n_unmasked) + # input embedding stem + self.tok_emb = nn.Embedding(config.vocab_size, config.n_embd) + self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd)) + self.drop = nn.Dropout(config.embd_pdrop) + # transformer + self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)]) + # decoder head + self.ln_f = nn.LayerNorm(config.n_embd) + self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False) + self.block_size = config.block_size + self.apply(self._init_weights) + self.config = config + logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters())) + + def get_block_size(self): + return self.block_size + + def _init_weights(self, module): + if isinstance(module, (nn.Linear, nn.Embedding)): + module.weight.data.normal_(mean=0.0, std=0.02) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + def forward(self, idx, embeddings=None, targets=None): + # forward the GPT model + token_embeddings = self.tok_emb(idx) # each index maps to a (learnable) vector + + if embeddings is not None: # prepend explicit embeddings + token_embeddings = torch.cat((embeddings, token_embeddings), dim=1) + + t = token_embeddings.shape[1] + assert t <= self.block_size, "Cannot forward, model block size is exhausted." + position_embeddings = self.pos_emb[:, :t, :] # each position maps to a (learnable) vector + x = self.drop(token_embeddings + position_embeddings) + x = self.blocks(x) + x = self.ln_f(x) + logits = self.head(x) + + # if we are given some desired targets also calculate the loss + loss = None + if targets is not None: + loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1)) + + return logits, loss + + def forward_with_past(self, idx, embeddings=None, targets=None, past=None, past_length=None): + # inference only + assert not self.training + token_embeddings = self.tok_emb(idx) # each index maps to a (learnable) vector + if embeddings is not None: # prepend explicit embeddings + token_embeddings = torch.cat((embeddings, token_embeddings), dim=1) + + if past is not None: + assert past_length is not None + past = torch.cat(past, dim=-2) # n_layer, 2, b, nh, len_past, dim_head + past_shape = list(past.shape) + expected_shape = [self.config.n_layer, 2, idx.shape[0], self.config.n_head, past_length, self.config.n_embd//self.config.n_head] + assert past_shape == expected_shape, f"{past_shape} =/= {expected_shape}" + position_embeddings = self.pos_emb[:, past_length, :] # each position maps to a (learnable) vector + else: + position_embeddings = self.pos_emb[:, :token_embeddings.shape[1], :] + + x = self.drop(token_embeddings + position_embeddings) + presents = [] # accumulate over layers + for i, block in enumerate(self.blocks): + x, present = block(x, layer_past=past[i, ...] if past is not None else None, return_present=True) + presents.append(present) + + x = self.ln_f(x) + logits = self.head(x) + # if we are given some desired targets also calculate the loss + loss = None + if targets is not None: + loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1)) + + return logits, loss, torch.stack(presents) # _, _, n_layer, 2, b, nh, 1, dim_head + + +class DummyGPT(nn.Module): + # for debugging + def __init__(self, add_value=1): + super().__init__() + self.add_value = add_value + + def forward(self, idx): + return idx + self.add_value, None + + +class CodeGPT(nn.Module): + """Takes in semi-embeddings""" + def __init__(self, vocab_size, block_size, in_channels, n_layer=12, n_head=8, n_embd=256, + embd_pdrop=0., resid_pdrop=0., attn_pdrop=0., n_unmasked=0): + super().__init__() + config = GPTConfig(vocab_size=vocab_size, block_size=block_size, + embd_pdrop=embd_pdrop, resid_pdrop=resid_pdrop, attn_pdrop=attn_pdrop, + n_layer=n_layer, n_head=n_head, n_embd=n_embd, + n_unmasked=n_unmasked) + # input embedding stem + self.tok_emb = nn.Linear(in_channels, config.n_embd) + self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd)) + self.drop = nn.Dropout(config.embd_pdrop) + # transformer + self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)]) + # decoder head + self.ln_f = nn.LayerNorm(config.n_embd) + self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False) + self.block_size = config.block_size + self.apply(self._init_weights) + self.config = config + logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters())) + + def get_block_size(self): + return self.block_size + + def _init_weights(self, module): + if isinstance(module, (nn.Linear, nn.Embedding)): + module.weight.data.normal_(mean=0.0, std=0.02) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + def forward(self, idx, embeddings=None, targets=None): + # forward the GPT model + token_embeddings = self.tok_emb(idx) # each index maps to a (learnable) vector + + if embeddings is not None: # prepend explicit embeddings + token_embeddings = torch.cat((embeddings, token_embeddings), dim=1) + + t = token_embeddings.shape[1] + assert t <= self.block_size, "Cannot forward, model block size is exhausted." + position_embeddings = self.pos_emb[:, :t, :] # each position maps to a (learnable) vector + x = self.drop(token_embeddings + position_embeddings) + x = self.blocks(x) + x = self.taming_cinln_f(x) + logits = self.head(x) + + # if we are given some desired targets also calculate the loss + loss = None + if targets is not None: + loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1)) + + return logits, loss + + + +#### sampling utils + +def top_k_logits(logits, k): + v, ix = torch.topk(logits, k) + out = logits.clone() + out[out < v[:, [-1]]] = -float('Inf') + return out + +@torch.no_grad() +def sample(model, x, steps, temperature=1.0, sample=False, top_k=None): + """ + take a conditioning sequence of indices in x (of shape (b,t)) and predict the next token in + the sequence, feeding the predictions back into the model each time. Clearly the sampling + has quadratic complexity unlike an RNN that is only linear, and has a finite context window + of block_size, unlike an RNN that has an infinite context window. + """ + block_size = model.get_block_size() + model.eval() + for k in range(steps): + x_cond = x if x.size(1) <= block_size else x[:, -block_size:] # crop context if needed + logits, _ = model(x_cond) + # pluck the logits at the final step and scale by temperature + logits = logits[:, -1, :] / temperature + # optionally crop probabilities to only the top k options + if top_k is not None: + logits = top_k_logits(logits, top_k) + # apply softmax to convert to probabilities + probs = F.softmax(logits, dim=-1) + # sample from the distribution or take the most likely + if sample: + ix = torch.multinomial(probs, num_samples=1) + else: + _, ix = torch.topk(probs, k=1, dim=-1) + # append to the sequence and continue + x = torch.cat((x, ix), dim=1) + + return x + + +@torch.no_grad() +def sample_with_past(x, model, steps, temperature=1., sample_logits=True, + top_k=None, top_p=None, callback=None): + # x is conditioning + sample = x + cond_len = x.shape[1] + past = None + for n in range(steps): + if callback is not None: + callback(n) + logits, _, present = model.forward_with_past(x, past=past, past_length=(n+cond_len-1)) + if past is None: + past = [present] + else: + past.append(present) + logits = logits[:, -1, :] / temperature + if top_k is not None: + logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p) + + probs = F.softmax(logits, dim=-1) + if not sample_logits: + _, x = torch.topk(probs, k=1, dim=-1) + else: + x = torch.multinomial(probs, num_samples=1) + # append to the sequence and continue + sample = torch.cat((sample, x), dim=1) + del past + sample = sample[:, cond_len:] # cut conditioning off + return sample + + +#### clustering utils + +class KMeans(nn.Module): + def __init__(self, ncluster=512, nc=3, niter=10): + super().__init__() + self.ncluster = ncluster + self.nc = nc + self.niter = niter + self.shape = (3,32,32) + self.register_buffer("C", torch.zeros(self.ncluster,nc)) + self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8)) + + def is_initialized(self): + return self.initialized.item() == 1 + + @torch.no_grad() + def initialize(self, x): + N, D = x.shape + assert D == self.nc, D + c = x[torch.randperm(N)[:self.ncluster]] # init clusters at random + for i in range(self.niter): + # assign all pixels to the closest codebook element + a = ((x[:, None, :] - c[None, :, :])**2).sum(-1).argmin(1) + # move each codebook element to be the mean of the pixels that assigned to it + c = torch.stack([x[a==k].mean(0) for k in range(self.ncluster)]) + # re-assign any poorly positioned codebook elements + nanix = torch.any(torch.isnan(c), dim=1) + ndead = nanix.sum().item() + print('done step %d/%d, re-initialized %d dead clusters' % (i+1, self.niter, ndead)) + c[nanix] = x[torch.randperm(N)[:ndead]] # re-init dead clusters + + self.C.copy_(c) + self.initialized.fill_(1) + + + def forward(self, x, reverse=False, shape=None): + if not reverse: + # flatten + bs,c,h,w = x.shape + assert c == self.nc + x = x.reshape(bs,c,h*w,1) + C = self.C.permute(1,0) + C = C.reshape(1,c,1,self.ncluster) + a = ((x-C)**2).sum(1).argmin(-1) # bs, h*w indices + return a + else: + # flatten + bs, HW = x.shape + """ + c = self.C.reshape( 1, self.nc, 1, self.ncluster) + c = c[bs*[0],:,:,:] + c = c[:,:,HW*[0],:] + x = x.reshape(bs, 1, HW, 1) + x = x[:,3*[0],:,:] + x = torch.gather(c, dim=3, index=x) + """ + x = self.C[x] + x = x.permute(0,2,1) + shape = shape if shape is not None else self.shape + x = x.reshape(bs, *shape) + + return x diff --git a/deforum-stable-diffusion/src/taming/modules/transformer/permuter.py b/deforum-stable-diffusion/src/taming/modules/transformer/permuter.py new file mode 100644 index 0000000000000000000000000000000000000000..0d43bb135adde38d94bf18a7e5edaa4523cd95cf --- /dev/null +++ b/deforum-stable-diffusion/src/taming/modules/transformer/permuter.py @@ -0,0 +1,248 @@ +import torch +import torch.nn as nn +import numpy as np + + +class AbstractPermuter(nn.Module): + def __init__(self, *args, **kwargs): + super().__init__() + def forward(self, x, reverse=False): + raise NotImplementedError + + +class Identity(AbstractPermuter): + def __init__(self): + super().__init__() + + def forward(self, x, reverse=False): + return x + + +class Subsample(AbstractPermuter): + def __init__(self, H, W): + super().__init__() + C = 1 + indices = np.arange(H*W).reshape(C,H,W) + while min(H, W) > 1: + indices = indices.reshape(C,H//2,2,W//2,2) + indices = indices.transpose(0,2,4,1,3) + indices = indices.reshape(C*4,H//2, W//2) + H = H//2 + W = W//2 + C = C*4 + assert H == W == 1 + idx = torch.tensor(indices.ravel()) + self.register_buffer('forward_shuffle_idx', + nn.Parameter(idx, requires_grad=False)) + self.register_buffer('backward_shuffle_idx', + nn.Parameter(torch.argsort(idx), requires_grad=False)) + + def forward(self, x, reverse=False): + if not reverse: + return x[:, self.forward_shuffle_idx] + else: + return x[:, self.backward_shuffle_idx] + + +def mortonify(i, j): + """(i,j) index to linear morton code""" + i = np.uint64(i) + j = np.uint64(j) + + z = np.uint(0) + + for pos in range(32): + z = (z | + ((j & (np.uint64(1) << np.uint64(pos))) << np.uint64(pos)) | + ((i & (np.uint64(1) << np.uint64(pos))) << np.uint64(pos+1)) + ) + return z + + +class ZCurve(AbstractPermuter): + def __init__(self, H, W): + super().__init__() + reverseidx = [np.int64(mortonify(i,j)) for i in range(H) for j in range(W)] + idx = np.argsort(reverseidx) + idx = torch.tensor(idx) + reverseidx = torch.tensor(reverseidx) + self.register_buffer('forward_shuffle_idx', + idx) + self.register_buffer('backward_shuffle_idx', + reverseidx) + + def forward(self, x, reverse=False): + if not reverse: + return x[:, self.forward_shuffle_idx] + else: + return x[:, self.backward_shuffle_idx] + + +class SpiralOut(AbstractPermuter): + def __init__(self, H, W): + super().__init__() + assert H == W + size = W + indices = np.arange(size*size).reshape(size,size) + + i0 = size//2 + j0 = size//2-1 + + i = i0 + j = j0 + + idx = [indices[i0, j0]] + step_mult = 0 + for c in range(1, size//2+1): + step_mult += 1 + # steps left + for k in range(step_mult): + i = i - 1 + j = j + idx.append(indices[i, j]) + + # step down + for k in range(step_mult): + i = i + j = j + 1 + idx.append(indices[i, j]) + + step_mult += 1 + if c < size//2: + # step right + for k in range(step_mult): + i = i + 1 + j = j + idx.append(indices[i, j]) + + # step up + for k in range(step_mult): + i = i + j = j - 1 + idx.append(indices[i, j]) + else: + # end reached + for k in range(step_mult-1): + i = i + 1 + idx.append(indices[i, j]) + + assert len(idx) == size*size + idx = torch.tensor(idx) + self.register_buffer('forward_shuffle_idx', idx) + self.register_buffer('backward_shuffle_idx', torch.argsort(idx)) + + def forward(self, x, reverse=False): + if not reverse: + return x[:, self.forward_shuffle_idx] + else: + return x[:, self.backward_shuffle_idx] + + +class SpiralIn(AbstractPermuter): + def __init__(self, H, W): + super().__init__() + assert H == W + size = W + indices = np.arange(size*size).reshape(size,size) + + i0 = size//2 + j0 = size//2-1 + + i = i0 + j = j0 + + idx = [indices[i0, j0]] + step_mult = 0 + for c in range(1, size//2+1): + step_mult += 1 + # steps left + for k in range(step_mult): + i = i - 1 + j = j + idx.append(indices[i, j]) + + # step down + for k in range(step_mult): + i = i + j = j + 1 + idx.append(indices[i, j]) + + step_mult += 1 + if c < size//2: + # step right + for k in range(step_mult): + i = i + 1 + j = j + idx.append(indices[i, j]) + + # step up + for k in range(step_mult): + i = i + j = j - 1 + idx.append(indices[i, j]) + else: + # end reached + for k in range(step_mult-1): + i = i + 1 + idx.append(indices[i, j]) + + assert len(idx) == size*size + idx = idx[::-1] + idx = torch.tensor(idx) + self.register_buffer('forward_shuffle_idx', idx) + self.register_buffer('backward_shuffle_idx', torch.argsort(idx)) + + def forward(self, x, reverse=False): + if not reverse: + return x[:, self.forward_shuffle_idx] + else: + return x[:, self.backward_shuffle_idx] + + +class Random(nn.Module): + def __init__(self, H, W): + super().__init__() + indices = np.random.RandomState(1).permutation(H*W) + idx = torch.tensor(indices.ravel()) + self.register_buffer('forward_shuffle_idx', idx) + self.register_buffer('backward_shuffle_idx', torch.argsort(idx)) + + def forward(self, x, reverse=False): + if not reverse: + return x[:, self.forward_shuffle_idx] + else: + return x[:, self.backward_shuffle_idx] + + +class AlternateParsing(AbstractPermuter): + def __init__(self, H, W): + super().__init__() + indices = np.arange(W*H).reshape(H,W) + for i in range(1, H, 2): + indices[i, :] = indices[i, ::-1] + idx = indices.flatten() + assert len(idx) == H*W + idx = torch.tensor(idx) + self.register_buffer('forward_shuffle_idx', idx) + self.register_buffer('backward_shuffle_idx', torch.argsort(idx)) + + def forward(self, x, reverse=False): + if not reverse: + return x[:, self.forward_shuffle_idx] + else: + return x[:, self.backward_shuffle_idx] + + +if __name__ == "__main__": + p0 = AlternateParsing(16, 16) + print(p0.forward_shuffle_idx) + print(p0.backward_shuffle_idx) + + x = torch.randint(0, 768, size=(11, 256)) + y = p0(x) + xre = p0(y, reverse=True) + assert torch.equal(x, xre) + + p1 = SpiralOut(2, 2) + print(p1.forward_shuffle_idx) + print(p1.backward_shuffle_idx) diff --git a/deforum-stable-diffusion/src/taming/modules/util.py b/deforum-stable-diffusion/src/taming/modules/util.py new file mode 100644 index 0000000000000000000000000000000000000000..9ee16385d8b1342a2d60a5f1aa5cadcfbe934bd8 --- /dev/null +++ b/deforum-stable-diffusion/src/taming/modules/util.py @@ -0,0 +1,130 @@ +import torch +import torch.nn as nn + + +def count_params(model): + total_params = sum(p.numel() for p in model.parameters()) + return total_params + + +class ActNorm(nn.Module): + def __init__(self, num_features, logdet=False, affine=True, + allow_reverse_init=False): + assert affine + super().__init__() + self.logdet = logdet + self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1)) + self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1)) + self.allow_reverse_init = allow_reverse_init + + self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8)) + + def initialize(self, input): + with torch.no_grad(): + flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1) + mean = ( + flatten.mean(1) + .unsqueeze(1) + .unsqueeze(2) + .unsqueeze(3) + .permute(1, 0, 2, 3) + ) + std = ( + flatten.std(1) + .unsqueeze(1) + .unsqueeze(2) + .unsqueeze(3) + .permute(1, 0, 2, 3) + ) + + self.loc.data.copy_(-mean) + self.scale.data.copy_(1 / (std + 1e-6)) + + def forward(self, input, reverse=False): + if reverse: + return self.reverse(input) + if len(input.shape) == 2: + input = input[:,:,None,None] + squeeze = True + else: + squeeze = False + + _, _, height, width = input.shape + + if self.training and self.initialized.item() == 0: + self.initialize(input) + self.initialized.fill_(1) + + h = self.scale * (input + self.loc) + + if squeeze: + h = h.squeeze(-1).squeeze(-1) + + if self.logdet: + log_abs = torch.log(torch.abs(self.scale)) + logdet = height*width*torch.sum(log_abs) + logdet = logdet * torch.ones(input.shape[0]).to(input) + return h, logdet + + return h + + def reverse(self, output): + if self.training and self.initialized.item() == 0: + if not self.allow_reverse_init: + raise RuntimeError( + "Initializing ActNorm in reverse direction is " + "disabled by default. Use allow_reverse_init=True to enable." + ) + else: + self.initialize(output) + self.initialized.fill_(1) + + if len(output.shape) == 2: + output = output[:,:,None,None] + squeeze = True + else: + squeeze = False + + h = output / self.scale - self.loc + + if squeeze: + h = h.squeeze(-1).squeeze(-1) + return h + + +class AbstractEncoder(nn.Module): + def __init__(self): + super().__init__() + + def encode(self, *args, **kwargs): + raise NotImplementedError + + +class Labelator(AbstractEncoder): + """Net2Net Interface for Class-Conditional Model""" + def __init__(self, n_classes, quantize_interface=True): + super().__init__() + self.n_classes = n_classes + self.quantize_interface = quantize_interface + + def encode(self, c): + c = c[:,None] + if self.quantize_interface: + return c, None, [None, None, c.long()] + return c + + +class SOSProvider(AbstractEncoder): + # for unconditional training + def __init__(self, sos_token, quantize_interface=True): + super().__init__() + self.sos_token = sos_token + self.quantize_interface = quantize_interface + + def encode(self, x): + # get batch size from data and replicate sos_token + c = torch.ones(x.shape[0], 1)*self.sos_token + c = c.long().to(x.device) + if self.quantize_interface: + return c, None, [None, None, c] + return c diff --git a/deforum-stable-diffusion/src/taming/modules/vqvae/quantize.py b/deforum-stable-diffusion/src/taming/modules/vqvae/quantize.py new file mode 100644 index 0000000000000000000000000000000000000000..d75544e41fa01bce49dd822b1037963d62f79b51 --- /dev/null +++ b/deforum-stable-diffusion/src/taming/modules/vqvae/quantize.py @@ -0,0 +1,445 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from torch import einsum +from einops import rearrange + + +class VectorQuantizer(nn.Module): + """ + see https://github.com/MishaLaskin/vqvae/blob/d761a999e2267766400dc646d82d3ac3657771d4/models/quantizer.py + ____________________________________________ + Discretization bottleneck part of the VQ-VAE. + Inputs: + - n_e : number of embeddings + - e_dim : dimension of embedding + - beta : commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2 + _____________________________________________ + """ + + # NOTE: this class contains a bug regarding beta; see VectorQuantizer2 for + # a fix and use legacy=False to apply that fix. VectorQuantizer2 can be + # used wherever VectorQuantizer has been used before and is additionally + # more efficient. + def __init__(self, n_e, e_dim, beta): + super(VectorQuantizer, self).__init__() + self.n_e = n_e + self.e_dim = e_dim + self.beta = beta + + self.embedding = nn.Embedding(self.n_e, self.e_dim) + self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e) + + def forward(self, z): + """ + Inputs the output of the encoder network z and maps it to a discrete + one-hot vector that is the index of the closest embedding vector e_j + z (continuous) -> z_q (discrete) + z.shape = (batch, channel, height, width) + quantization pipeline: + 1. get encoder input (B,C,H,W) + 2. flatten input to (B*H*W,C) + """ + # reshape z -> (batch, height, width, channel) and flatten + z = z.permute(0, 2, 3, 1).contiguous() + z_flattened = z.view(-1, self.e_dim) + # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z + + d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \ + torch.sum(self.embedding.weight**2, dim=1) - 2 * \ + torch.matmul(z_flattened, self.embedding.weight.t()) + + ## could possible replace this here + # #\start... + # find closest encodings + min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1) + + min_encodings = torch.zeros( + min_encoding_indices.shape[0], self.n_e).to(z) + min_encodings.scatter_(1, min_encoding_indices, 1) + + # dtype min encodings: torch.float32 + # min_encodings shape: torch.Size([2048, 512]) + # min_encoding_indices.shape: torch.Size([2048, 1]) + + # get quantized latent vectors + z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape) + #.........\end + + # with: + # .........\start + #min_encoding_indices = torch.argmin(d, dim=1) + #z_q = self.embedding(min_encoding_indices) + # ......\end......... (TODO) + + # compute loss for embedding + loss = torch.mean((z_q.detach()-z)**2) + self.beta * \ + torch.mean((z_q - z.detach()) ** 2) + + # preserve gradients + z_q = z + (z_q - z).detach() + + # perplexity + e_mean = torch.mean(min_encodings, dim=0) + perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10))) + + # reshape back to match original input shape + z_q = z_q.permute(0, 3, 1, 2).contiguous() + + return z_q, loss, (perplexity, min_encodings, min_encoding_indices) + + def get_codebook_entry(self, indices, shape): + # shape specifying (batch, height, width, channel) + # TODO: check for more easy handling with nn.Embedding + min_encodings = torch.zeros(indices.shape[0], self.n_e).to(indices) + min_encodings.scatter_(1, indices[:,None], 1) + + # get quantized latent vectors + z_q = torch.matmul(min_encodings.float(), self.embedding.weight) + + if shape is not None: + z_q = z_q.view(shape) + + # reshape back to match original input shape + z_q = z_q.permute(0, 3, 1, 2).contiguous() + + return z_q + + +class GumbelQuantize(nn.Module): + """ + credit to @karpathy: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py (thanks!) + Gumbel Softmax trick quantizer + Categorical Reparameterization with Gumbel-Softmax, Jang et al. 2016 + https://arxiv.org/abs/1611.01144 + """ + def __init__(self, num_hiddens, embedding_dim, n_embed, straight_through=True, + kl_weight=5e-4, temp_init=1.0, use_vqinterface=True, + remap=None, unknown_index="random"): + super().__init__() + + self.embedding_dim = embedding_dim + self.n_embed = n_embed + + self.straight_through = straight_through + self.temperature = temp_init + self.kl_weight = kl_weight + + self.proj = nn.Conv2d(num_hiddens, n_embed, 1) + self.embed = nn.Embedding(n_embed, embedding_dim) + + self.use_vqinterface = use_vqinterface + + self.remap = remap + if self.remap is not None: + self.register_buffer("used", torch.tensor(np.load(self.remap))) + self.re_embed = self.used.shape[0] + self.unknown_index = unknown_index # "random" or "extra" or integer + if self.unknown_index == "extra": + self.unknown_index = self.re_embed + self.re_embed = self.re_embed+1 + print(f"Remapping {self.n_embed} indices to {self.re_embed} indices. " + f"Using {self.unknown_index} for unknown indices.") + else: + self.re_embed = n_embed + + def remap_to_used(self, inds): + ishape = inds.shape + assert len(ishape)>1 + inds = inds.reshape(ishape[0],-1) + used = self.used.to(inds) + match = (inds[:,:,None]==used[None,None,...]).long() + new = match.argmax(-1) + unknown = match.sum(2)<1 + if self.unknown_index == "random": + new[unknown]=torch.randint(0,self.re_embed,size=new[unknown].shape).to(device=new.device) + else: + new[unknown] = self.unknown_index + return new.reshape(ishape) + + def unmap_to_all(self, inds): + ishape = inds.shape + assert len(ishape)>1 + inds = inds.reshape(ishape[0],-1) + used = self.used.to(inds) + if self.re_embed > self.used.shape[0]: # extra token + inds[inds>=self.used.shape[0]] = 0 # simply set to zero + back=torch.gather(used[None,:][inds.shape[0]*[0],:], 1, inds) + return back.reshape(ishape) + + def forward(self, z, temp=None, return_logits=False): + # force hard = True when we are in eval mode, as we must quantize. actually, always true seems to work + hard = self.straight_through if self.training else True + temp = self.temperature if temp is None else temp + + logits = self.proj(z) + if self.remap is not None: + # continue only with used logits + full_zeros = torch.zeros_like(logits) + logits = logits[:,self.used,...] + + soft_one_hot = F.gumbel_softmax(logits, tau=temp, dim=1, hard=hard) + if self.remap is not None: + # go back to all entries but unused set to zero + full_zeros[:,self.used,...] = soft_one_hot + soft_one_hot = full_zeros + z_q = einsum('b n h w, n d -> b d h w', soft_one_hot, self.embed.weight) + + # + kl divergence to the prior loss + qy = F.softmax(logits, dim=1) + diff = self.kl_weight * torch.sum(qy * torch.log(qy * self.n_embed + 1e-10), dim=1).mean() + + ind = soft_one_hot.argmax(dim=1) + if self.remap is not None: + ind = self.remap_to_used(ind) + if self.use_vqinterface: + if return_logits: + return z_q, diff, (None, None, ind), logits + return z_q, diff, (None, None, ind) + return z_q, diff, ind + + def get_codebook_entry(self, indices, shape): + b, h, w, c = shape + assert b*h*w == indices.shape[0] + indices = rearrange(indices, '(b h w) -> b h w', b=b, h=h, w=w) + if self.remap is not None: + indices = self.unmap_to_all(indices) + one_hot = F.one_hot(indices, num_classes=self.n_embed).permute(0, 3, 1, 2).float() + z_q = einsum('b n h w, n d -> b d h w', one_hot, self.embed.weight) + return z_q + + +class VectorQuantizer2(nn.Module): + """ + Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly + avoids costly matrix multiplications and allows for post-hoc remapping of indices. + """ + # NOTE: due to a bug the beta term was applied to the wrong term. for + # backwards compatibility we use the buggy version by default, but you can + # specify legacy=False to fix it. + def __init__(self, n_e, e_dim, beta, remap=None, unknown_index="random", + sane_index_shape=False, legacy=True): + super().__init__() + self.n_e = n_e + self.e_dim = e_dim + self.beta = beta + self.legacy = legacy + + self.embedding = nn.Embedding(self.n_e, self.e_dim) + self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e) + + self.remap = remap + if self.remap is not None: + self.register_buffer("used", torch.tensor(np.load(self.remap))) + self.re_embed = self.used.shape[0] + self.unknown_index = unknown_index # "random" or "extra" or integer + if self.unknown_index == "extra": + self.unknown_index = self.re_embed + self.re_embed = self.re_embed+1 + print(f"Remapping {self.n_e} indices to {self.re_embed} indices. " + f"Using {self.unknown_index} for unknown indices.") + else: + self.re_embed = n_e + + self.sane_index_shape = sane_index_shape + + def remap_to_used(self, inds): + ishape = inds.shape + assert len(ishape)>1 + inds = inds.reshape(ishape[0],-1) + used = self.used.to(inds) + match = (inds[:,:,None]==used[None,None,...]).long() + new = match.argmax(-1) + unknown = match.sum(2)<1 + if self.unknown_index == "random": + new[unknown]=torch.randint(0,self.re_embed,size=new[unknown].shape).to(device=new.device) + else: + new[unknown] = self.unknown_index + return new.reshape(ishape) + + def unmap_to_all(self, inds): + ishape = inds.shape + assert len(ishape)>1 + inds = inds.reshape(ishape[0],-1) + used = self.used.to(inds) + if self.re_embed > self.used.shape[0]: # extra token + inds[inds>=self.used.shape[0]] = 0 # simply set to zero + back=torch.gather(used[None,:][inds.shape[0]*[0],:], 1, inds) + return back.reshape(ishape) + + def forward(self, z, temp=None, rescale_logits=False, return_logits=False): + assert temp is None or temp==1.0, "Only for interface compatible with Gumbel" + assert rescale_logits==False, "Only for interface compatible with Gumbel" + assert return_logits==False, "Only for interface compatible with Gumbel" + # reshape z -> (batch, height, width, channel) and flatten + z = rearrange(z, 'b c h w -> b h w c').contiguous() + z_flattened = z.view(-1, self.e_dim) + # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z + + d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \ + torch.sum(self.embedding.weight**2, dim=1) - 2 * \ + torch.einsum('bd,dn->bn', z_flattened, rearrange(self.embedding.weight, 'n d -> d n')) + + min_encoding_indices = torch.argmin(d, dim=1) + z_q = self.embedding(min_encoding_indices).view(z.shape) + perplexity = None + min_encodings = None + + # compute loss for embedding + if not self.legacy: + loss = self.beta * torch.mean((z_q.detach()-z)**2) + \ + torch.mean((z_q - z.detach()) ** 2) + else: + loss = torch.mean((z_q.detach()-z)**2) + self.beta * \ + torch.mean((z_q - z.detach()) ** 2) + + # preserve gradients + z_q = z + (z_q - z).detach() + + # reshape back to match original input shape + z_q = rearrange(z_q, 'b h w c -> b c h w').contiguous() + + if self.remap is not None: + min_encoding_indices = min_encoding_indices.reshape(z.shape[0],-1) # add batch axis + min_encoding_indices = self.remap_to_used(min_encoding_indices) + min_encoding_indices = min_encoding_indices.reshape(-1,1) # flatten + + if self.sane_index_shape: + min_encoding_indices = min_encoding_indices.reshape( + z_q.shape[0], z_q.shape[2], z_q.shape[3]) + + return z_q, loss, (perplexity, min_encodings, min_encoding_indices) + + def get_codebook_entry(self, indices, shape): + # shape specifying (batch, height, width, channel) + if self.remap is not None: + indices = indices.reshape(shape[0],-1) # add batch axis + indices = self.unmap_to_all(indices) + indices = indices.reshape(-1) # flatten again + + # get quantized latent vectors + z_q = self.embedding(indices) + + if shape is not None: + z_q = z_q.view(shape) + # reshape back to match original input shape + z_q = z_q.permute(0, 3, 1, 2).contiguous() + + return z_q + +class EmbeddingEMA(nn.Module): + def __init__(self, num_tokens, codebook_dim, decay=0.99, eps=1e-5): + super().__init__() + self.decay = decay + self.eps = eps + weight = torch.randn(num_tokens, codebook_dim) + self.weight = nn.Parameter(weight, requires_grad = False) + self.cluster_size = nn.Parameter(torch.zeros(num_tokens), requires_grad = False) + self.embed_avg = nn.Parameter(weight.clone(), requires_grad = False) + self.update = True + + def forward(self, embed_id): + return F.embedding(embed_id, self.weight) + + def cluster_size_ema_update(self, new_cluster_size): + self.cluster_size.data.mul_(self.decay).add_(new_cluster_size, alpha=1 - self.decay) + + def embed_avg_ema_update(self, new_embed_avg): + self.embed_avg.data.mul_(self.decay).add_(new_embed_avg, alpha=1 - self.decay) + + def weight_update(self, num_tokens): + n = self.cluster_size.sum() + smoothed_cluster_size = ( + (self.cluster_size + self.eps) / (n + num_tokens * self.eps) * n + ) + #normalize embedding average with smoothed cluster size + embed_normalized = self.embed_avg / smoothed_cluster_size.unsqueeze(1) + self.weight.data.copy_(embed_normalized) + + +class EMAVectorQuantizer(nn.Module): + def __init__(self, n_embed, embedding_dim, beta, decay=0.99, eps=1e-5, + remap=None, unknown_index="random"): + super().__init__() + self.codebook_dim = codebook_dim + self.num_tokens = num_tokens + self.beta = beta + self.embedding = EmbeddingEMA(self.num_tokens, self.codebook_dim, decay, eps) + + self.remap = remap + if self.remap is not None: + self.register_buffer("used", torch.tensor(np.load(self.remap))) + self.re_embed = self.used.shape[0] + self.unknown_index = unknown_index # "random" or "extra" or integer + if self.unknown_index == "extra": + self.unknown_index = self.re_embed + self.re_embed = self.re_embed+1 + print(f"Remapping {self.n_embed} indices to {self.re_embed} indices. " + f"Using {self.unknown_index} for unknown indices.") + else: + self.re_embed = n_embed + + def remap_to_used(self, inds): + ishape = inds.shape + assert len(ishape)>1 + inds = inds.reshape(ishape[0],-1) + used = self.used.to(inds) + match = (inds[:,:,None]==used[None,None,...]).long() + new = match.argmax(-1) + unknown = match.sum(2)<1 + if self.unknown_index == "random": + new[unknown]=torch.randint(0,self.re_embed,size=new[unknown].shape).to(device=new.device) + else: + new[unknown] = self.unknown_index + return new.reshape(ishape) + + def unmap_to_all(self, inds): + ishape = inds.shape + assert len(ishape)>1 + inds = inds.reshape(ishape[0],-1) + used = self.used.to(inds) + if self.re_embed > self.used.shape[0]: # extra token + inds[inds>=self.used.shape[0]] = 0 # simply set to zero + back=torch.gather(used[None,:][inds.shape[0]*[0],:], 1, inds) + return back.reshape(ishape) + + def forward(self, z): + # reshape z -> (batch, height, width, channel) and flatten + #z, 'b c h w -> b h w c' + z = rearrange(z, 'b c h w -> b h w c') + z_flattened = z.reshape(-1, self.codebook_dim) + + # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z + d = z_flattened.pow(2).sum(dim=1, keepdim=True) + \ + self.embedding.weight.pow(2).sum(dim=1) - 2 * \ + torch.einsum('bd,nd->bn', z_flattened, self.embedding.weight) # 'n d -> d n' + + + encoding_indices = torch.argmin(d, dim=1) + + z_q = self.embedding(encoding_indices).view(z.shape) + encodings = F.one_hot(encoding_indices, self.num_tokens).type(z.dtype) + avg_probs = torch.mean(encodings, dim=0) + perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10))) + + if self.training and self.embedding.update: + #EMA cluster size + encodings_sum = encodings.sum(0) + self.embedding.cluster_size_ema_update(encodings_sum) + #EMA embedding average + embed_sum = encodings.transpose(0,1) @ z_flattened + self.embedding.embed_avg_ema_update(embed_sum) + #normalize embed_avg and update weight + self.embedding.weight_update(self.num_tokens) + + # compute loss for embedding + loss = self.beta * F.mse_loss(z_q.detach(), z) + + # preserve gradients + z_q = z + (z_q - z).detach() + + # reshape back to match original input shape + #z_q, 'b h w c -> b c h w' + z_q = rearrange(z_q, 'b h w c -> b c h w') + return z_q, loss, (perplexity, encodings, encoding_indices) diff --git a/deforum-stable-diffusion/src/taming/util.py b/deforum-stable-diffusion/src/taming/util.py new file mode 100644 index 0000000000000000000000000000000000000000..06053e5defb87977f9ab07e69bf4da12201de9b7 --- /dev/null +++ b/deforum-stable-diffusion/src/taming/util.py @@ -0,0 +1,157 @@ +import os, hashlib +import requests +from tqdm import tqdm + +URL_MAP = { + "vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1" +} + +CKPT_MAP = { + "vgg_lpips": "vgg.pth" +} + +MD5_MAP = { + "vgg_lpips": "d507d7349b931f0638a25a48a722f98a" +} + + +def download(url, local_path, chunk_size=1024): + os.makedirs(os.path.split(local_path)[0], exist_ok=True) + with requests.get(url, stream=True) as r: + total_size = int(r.headers.get("content-length", 0)) + with tqdm(total=total_size, unit="B", unit_scale=True) as pbar: + with open(local_path, "wb") as f: + for data in r.iter_content(chunk_size=chunk_size): + if data: + f.write(data) + pbar.update(chunk_size) + + +def md5_hash(path): + with open(path, "rb") as f: + content = f.read() + return hashlib.md5(content).hexdigest() + + +def get_ckpt_path(name, root, check=False): + assert name in URL_MAP + path = os.path.join(root, CKPT_MAP[name]) + if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]): + print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path)) + download(URL_MAP[name], path) + md5 = md5_hash(path) + assert md5 == MD5_MAP[name], md5 + return path + + +class KeyNotFoundError(Exception): + def __init__(self, cause, keys=None, visited=None): + self.cause = cause + self.keys = keys + self.visited = visited + messages = list() + if keys is not None: + messages.append("Key not found: {}".format(keys)) + if visited is not None: + messages.append("Visited: {}".format(visited)) + messages.append("Cause:\n{}".format(cause)) + message = "\n".join(messages) + super().__init__(message) + + +def retrieve( + list_or_dict, key, splitval="/", default=None, expand=True, pass_success=False +): + """Given a nested list or dict return the desired value at key expanding + callable nodes if necessary and :attr:`expand` is ``True``. The expansion + is done in-place. + + Parameters + ---------- + list_or_dict : list or dict + Possibly nested list or dictionary. + key : str + key/to/value, path like string describing all keys necessary to + consider to get to the desired value. List indices can also be + passed here. + splitval : str + String that defines the delimiter between keys of the + different depth levels in `key`. + default : obj + Value returned if :attr:`key` is not found. + expand : bool + Whether to expand callable nodes on the path or not. + + Returns + ------- + The desired value or if :attr:`default` is not ``None`` and the + :attr:`key` is not found returns ``default``. + + Raises + ------ + Exception if ``key`` not in ``list_or_dict`` and :attr:`default` is + ``None``. + """ + + keys = key.split(splitval) + + success = True + try: + visited = [] + parent = None + last_key = None + for key in keys: + if callable(list_or_dict): + if not expand: + raise KeyNotFoundError( + ValueError( + "Trying to get past callable node with expand=False." + ), + keys=keys, + visited=visited, + ) + list_or_dict = list_or_dict() + parent[last_key] = list_or_dict + + last_key = key + parent = list_or_dict + + try: + if isinstance(list_or_dict, dict): + list_or_dict = list_or_dict[key] + else: + list_or_dict = list_or_dict[int(key)] + except (KeyError, IndexError, ValueError) as e: + raise KeyNotFoundError(e, keys=keys, visited=visited) + + visited += [key] + # final expansion of retrieved value + if expand and callable(list_or_dict): + list_or_dict = list_or_dict() + parent[last_key] = list_or_dict + except KeyNotFoundError as e: + if default is None: + raise e + else: + list_or_dict = default + success = False + + if not pass_success: + return list_or_dict + else: + return list_or_dict, success + + +if __name__ == "__main__": + config = {"keya": "a", + "keyb": "b", + "keyc": + {"cc1": 1, + "cc2": 2, + } + } + from omegaconf import OmegaConf + config = OmegaConf.create(config) + print(config) + retrieve(config, "keya") + diff --git a/deforum-stable-diffusion/src/temp.txt b/deforum-stable-diffusion/src/temp.txt new file mode 100644 index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc --- /dev/null +++ b/deforum-stable-diffusion/src/temp.txt @@ -0,0 +1 @@ + diff --git a/deforum-stable-diffusion/src/types/inference.py b/deforum-stable-diffusion/src/types/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..c32fdddf063c62f12da93df618b259c8c089f6e4 --- /dev/null +++ b/deforum-stable-diffusion/src/types/inference.py @@ -0,0 +1,248 @@ +from datetime import time +import io +import os +import numpy as np +from pydantic import BaseModel, Field +from typing import List, Literal, Optional, Union +import requests + +import torch +from helpers.save_images import get_output_folder +from PIL import Image +from PIL.GifImagePlugin import GifImageFile +from PIL.JpegImagePlugin import JpegImageFile +from PIL.PngImagePlugin import PngImageFile +from PIL.TiffImagePlugin import TiffImageFile +from pydantic import ( + BaseConfig, + BaseModel, + Field, + validator, +) +import validators +from devtools import debug + +ImageType = Union[JpegImageFile, PngImageFile, GifImageFile, TiffImageFile, Image.Image] + + +def is_image(v): + return ( + isinstance(v, PngImageFile) + or isinstance(v, JpegImageFile) + or isinstance(v, GifImageFile) + or isinstance(v, TiffImageFile) + or isinstance(v, Image.Image) + ) + + +# Helper for images +def validate_image(v, throw=True): + if v is None: + return v + elif ( + isinstance(v, PngImageFile) + or isinstance(v, JpegImageFile) + or isinstance(v, GifImageFile) + or isinstance(v, TiffImageFile) + or isinstance(v, Image.Image) + ): + return v + elif isinstance(v, bytes) or isinstance(v, io.BytesIO): + return v + elif isinstance(v, str): + if validators.url(v): + try: + v = Image.open(requests.get(v, stream=True).raw).convert("RGB") + return v + except Exception as err: + if throw: + raise ValueError( + "Invalid remote url, failed to parse image" + ) from err + else: + return False + elif os.path.isfile(v): + try: + with Image.open(v) as fd: + return fd.convert("RGB") + except Exception as err: + if throw: + raise ValueError( + "Invalid path, failed to parse image from local path" + ) from err + else: + return False + else: + if throw: + raise ValueError("Invalid string, no image or remote url") + else: + return False + else: + if throw: + raise ValueError( + f"Bad image type. Expected: bytes, Image, or Image url. Got: {debug.format(v)}" + ) + else: + return False + + +def output_folder_factory(output_path="outputs", batch_folder="deforum"): + prefix = os.path.abspath(os.path.dirname(__file__)) + return get_output_folder(f"{prefix}/{output_path}".replace("//", "/"), batch_folder) + + +class DeforumArgs(BaseModel): + W: Optional[int] = 512 + H: Optional[int] = 512 + seed: Optional[int] = -1 + sampler: Optional[ + Literal[ + "klms", + "dpm2", + "dpm2_ancestral", + "heun", + "euler", + "euler_ancestral", + "plms", + "ddim", + ] + ] = "euler_ancestral" + steps: Optional[int] = 80 + scale: Optional[int] = 7 + ddim_eta: Optional[float] = 0.0 + dynamic_threshold: Optional[float] = None + static_threshold: Optional[float] = None + save_samples: Optional[bool] = True + save_settings: Optional[bool] = True + display_samples: Optional[bool] = True + save_sample_per_step: Optional[bool] = False + show_sample_per_step: Optional[bool] = False + prompt_weighting: Optional[bool] = False + normalize_prompt_weights: Optional[bool] = False + log_weighted_subprompts: Optional[bool] = False + n_batch: Optional[int] = 1 + batch_name: Optional[str] = "StableFun" + filename_format: Optional[ + Literal["{timestring}_{index}_{seed}.png", "{timestring}_{index}_{prompt}.png"] + ] = "{timestring}_{index}_{prompt}.png" + seed_behavior: Optional[Literal["iter", "constant", "random"]] = "iter" + make_grid: Optional[bool] = False + grid_rows: Optional[int] = 2 + outdir: Optional[str] = Field(default_factory=output_folder_factory) + use_init: Optional[bool] = False + strength: Optional[float] = 0.0 + strength_0_no_init: Optional[bool] = True + init_image: Optional[ + ImageType + ] = "https://cdn.pixabay.com/photo/2022/07/30/13/10/green-longhorn-beetle-7353749_1280.jpg" + use_mask: Optional[bool] = False + use_alpha_as_mask: Optional[bool] = False + mask_file: Optional[ + ImageType + ] = "https://www.filterforge.com/wiki/images/archive/b/b7/20080927223728%21Polygonal_gradient_thumb.jpg" + invert_mask: Optional[bool] = False + mask_brightness_adjust: Optional[float] = 1.0 + mask_contrast_adjust: Optional[float] = 1.0 + overlay_mask: Optional[bool] = True + mask_overlay_blur: Optional[float] = 5 + mean_loss_scale: Optional[float] = 0 + var_loss_scale: Optional[float] = 0 + exposure_loss_scale: Optional[float] = 0 + exposure_target: Optional[float] = 0.5 + colormatch_loss_scale: Optional[float] = 0 + colormatch_image: Optional[ + ImageType + ] = "https://www.saasdesign.io/wp-content/uploads/2021/02/palette-3-min-980x588.png" + colormatch_n_colors: Optional[int] = 4 + ignore_sat_scale: Optional[float] = 0 + clip_name: Optional[ + Literal["ViT-L/14", "ViT-L/14@336px", "ViT-B/16", "ViT-B/32"] + ] = "ViT-L/14" + clip_loss_scale: Optional[float] = 0 + aesthetics_loss_scale: Optional[float] = 0 + cutn: Optional[int] = 1 + cut_pow: Optional[float] = 0.0001 + init_mse_scale: Optional[float] = 0 + blue_loss_scale: Optional[float] = 0 + gradient_wrt: Optional[Literal["x", "x0_pred"]] = "x0_pred" + gradient_add_to: Optional[Literal["cond", "uncond", "both"]] = "both" + decode_method: Optional[Literal["autoencoder", "linear"]] = "linear" + grad_threshold_type: Optional[ + Literal["dynamic", "static", "mean", "schedule"] + ] = "dynamic" + clamp_grad_threshold: Optional[float] = 0.2 + clamp_start: Optional[float] = 0.2 + clamp_stop: Optional[float] = 0.01 + cond_uncond_sync: Optional[bool] = True + n_samples: Optional[int] = 1 + precision: Optional[Literal["fp16", "autocast", "fp32"]] = "autocast" + C: Optional[int] = 4 + f: Optional[int] = 8 + prompt: Optional[str] = "" + timestring: Optional[str] = Field( + default_factory=lambda: time.strftime("%Y%m%d%H%M%S") + ) + init_latent: Optional[Union[float, torch.Tensor, np.ndarray]] = None + init_sample: Optional[Union[float, torch.Tensor, np.ndarray]] = None + init_c: Optional[Union[float, torch.Tensor, np.ndarray]] = None + + class Config(BaseConfig): + arbitrary_types_allowed: Optional[bool] = True + + @validator("init_image", pre=True) + def validate_image_init(cls, v): + return validate_image(v) + + @validator("colormatch_image", pre=True) + def validate_image_colormatch(cls, v): + return validate_image(v) + + @validator("mask_file", pre=True) + def validate_image_mask(cls, v): + return validate_image(v) + + +class DeforumAnimArgs(BaseModel): + animation_mode: Optional[ + Literal["None", "2D", "3D", "Video Input", "Interpolation"] + ] = "None" + max_frames: Optional[int] = 1000 + border: Optional[Literal["wrap", "replicate"]] = "replicate" + angle: Optional[str] = "0:(0)" + zoom: Optional[str] = "0:(1.04)" + translation_x: Optional[str] = "0:(10*sin(2*3.14*t/10))" + translation_y: Optional[str] = "0:(0)" + translation_z: Optional[str] = "0:(10)" + rotation_3d_x: Optional[str] = "0:(0)" + rotation_3d_y: Optional[str] = "0:(0)" + rotation_3d_z: Optional[str] = "0:(0)" + flip_2d_perspective: Optional[bool] = False + perspective_flip_theta: Optional[str] = "0:(0)" + perspective_flip_phi: Optional[str] = "0:(t%15)" + perspective_flip_gamma: Optional[str] = "0:(0)" + perspective_flip_fv: Optional[str] = "0:(53)" + noise_schedule: Optional[str] = "0: (0.02)" + strength_schedule: Optional[str] = "0: (0.65)" + contrast_schedule: Optional[str] = "0: (1.0)" + color_coherence: Optional[str] = "Match Frame 0 LAB" + diffusion_cadence: Optional[Literal["1", "2", "3", "4", "5", "6", "7", "8"]] = "1" + use_depth_warping: Optional[bool] = True + midas_weight: Optional[float] = 0.3 + near_plane: Optional[int] = 200 + far_plane: Optional[int] = 10000 + fov: Optional[float] = 40 + padding_mode: Optional[Literal["border", "reflection", "zeros"]] = "border" + sampling_mode: Optional[Literal["bicubic", "bilinear", "nearest"]] = "bicubic" + save_depth_maps: Optional[bool] = False + video_init_path: Optional[str] = "/content/video_in.mp4" + extract_nth_frame: Optional[int] = 1 + overwrite_extracted_frames: Optional[bool] = True + use_mask_video: Optional[bool] = False + video_mask_path: Optional[str] = "/content/video_in.mp4" + interpolate_key_frames: Optional[bool] = False + interpolate_x_frames: Optional[int] = 4 + resume_from_timestring: Optional[bool] = False + resume_timestring: Optional[str] = "20220829210106" + + class Config(BaseConfig): + arbitrary_types_allowed: Optional[bool] = True diff --git a/deforum-stable-diffusion/src/utils.py b/deforum-stable-diffusion/src/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..fbe08b0b1bd41f2bc59e9f8d188db08423fcf48a --- /dev/null +++ b/deforum-stable-diffusion/src/utils.py @@ -0,0 +1,140 @@ +import base64 +import math +import re +from io import BytesIO + +import matplotlib.cm +import numpy as np +import torch +import torch.nn +from PIL import Image + + +class RunningAverage: + def __init__(self): + self.avg = 0 + self.count = 0 + + def append(self, value): + self.avg = (value + self.count * self.avg) / (self.count + 1) + self.count += 1 + + def get_value(self): + return self.avg + + +def denormalize(x, device='cpu'): + mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device) + std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device) + return x * std + mean + + +class RunningAverageDict: + def __init__(self): + self._dict = None + + def update(self, new_dict): + if self._dict is None: + self._dict = dict() + for key, value in new_dict.items(): + self._dict[key] = RunningAverage() + + for key, value in new_dict.items(): + self._dict[key].append(value) + + def get_value(self): + return {key: value.get_value() for key, value in self._dict.items()} + + +def colorize(value, vmin=10, vmax=1000, cmap='magma_r'): + value = value.cpu().numpy()[0, :, :] + invalid_mask = value == -1 + + # normalize + vmin = value.min() if vmin is None else vmin + vmax = value.max() if vmax is None else vmax + if vmin != vmax: + value = (value - vmin) / (vmax - vmin) # vmin..vmax + else: + # Avoid 0-division + value = value * 0. + # squeeze last dim if it exists + # value = value.squeeze(axis=0) + cmapper = matplotlib.cm.get_cmap(cmap) + value = cmapper(value, bytes=True) # (nxmx4) + value[invalid_mask] = 255 + img = value[:, :, :3] + + # return img.transpose((2, 0, 1)) + return img + + +def count_parameters(model): + return sum(p.numel() for p in model.parameters() if p.requires_grad) + + +def compute_errors(gt, pred): + thresh = np.maximum((gt / pred), (pred / gt)) + a1 = (thresh < 1.25).mean() + a2 = (thresh < 1.25 ** 2).mean() + a3 = (thresh < 1.25 ** 3).mean() + + abs_rel = np.mean(np.abs(gt - pred) / gt) + sq_rel = np.mean(((gt - pred) ** 2) / gt) + + rmse = (gt - pred) ** 2 + rmse = np.sqrt(rmse.mean()) + + rmse_log = (np.log(gt) - np.log(pred)) ** 2 + rmse_log = np.sqrt(rmse_log.mean()) + + err = np.log(pred) - np.log(gt) + silog = np.sqrt(np.mean(err ** 2) - np.mean(err) ** 2) * 100 + + log_10 = (np.abs(np.log10(gt) - np.log10(pred))).mean() + return dict(a1=a1, a2=a2, a3=a3, abs_rel=abs_rel, rmse=rmse, log_10=log_10, rmse_log=rmse_log, + silog=silog, sq_rel=sq_rel) + + +##################################### Demo Utilities ############################################ +def b64_to_pil(b64string): + image_data = re.sub('^data:image/.+;base64,', '', b64string) + # image = Image.open(cStringIO.StringIO(image_data)) + return Image.open(BytesIO(base64.b64decode(image_data))) + + +# Compute edge magnitudes +from scipy import ndimage + + +def edges(d): + dx = ndimage.sobel(d, 0) # horizontal derivative + dy = ndimage.sobel(d, 1) # vertical derivative + return np.abs(dx) + np.abs(dy) + + +class PointCloudHelper(): + def __init__(self, width=640, height=480): + self.xx, self.yy = self.worldCoords(width, height) + + def worldCoords(self, width=640, height=480): + hfov_degrees, vfov_degrees = 57, 43 + hFov = math.radians(hfov_degrees) + vFov = math.radians(vfov_degrees) + cx, cy = width / 2, height / 2 + fx = width / (2 * math.tan(hFov / 2)) + fy = height / (2 * math.tan(vFov / 2)) + xx, yy = np.tile(range(width), height), np.repeat(range(height), width) + xx = (xx - cx) / fx + yy = (yy - cy) / fy + return xx, yy + + def depth_to_points(self, depth): + depth[edges(depth) > 0.3] = np.nan # Hide depth edges + length = depth.shape[0] * depth.shape[1] + # depth[edges(depth) > 0.3] = 1e6 # Hide depth edges + z = depth.reshape(length) + + return np.dstack((self.xx * z, self.yy * z, z)).reshape((length, 3)) + +##################################################################################################### diff --git a/environments/README.md b/environments/README.md new file mode 100644 index 0000000000000000000000000000000000000000..a60542d3e7c7d9b387101399d7690b03da8c4b37 --- /dev/null +++ b/environments/README.md @@ -0,0 +1,2 @@ +This file folder includes the environment to run the video stable diffusion. + diff --git a/geffnet/__init__.py b/geffnet/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2e441a5838d1e972823b9668ac8d459445f6f6ce --- /dev/null +++ b/geffnet/__init__.py @@ -0,0 +1,5 @@ +from .gen_efficientnet import * +from .mobilenetv3 import * +from .model_factory import create_model +from .config import is_exportable, is_scriptable, set_exportable, set_scriptable +from .activations import * \ No newline at end of file diff --git a/geffnet/__pycache__/__init__.cpython-39.pyc b/geffnet/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..50d0bf36f50c0fa7296f6326c2e7a6f260e3d50d Binary files /dev/null and b/geffnet/__pycache__/__init__.cpython-39.pyc differ diff --git a/geffnet/__pycache__/config.cpython-39.pyc b/geffnet/__pycache__/config.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3075a27729e1ca7e9ae4ae77fdd7488810d30305 Binary files /dev/null and b/geffnet/__pycache__/config.cpython-39.pyc differ diff --git a/geffnet/__pycache__/conv2d_layers.cpython-39.pyc b/geffnet/__pycache__/conv2d_layers.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..31738c7a45cc34d70e9279aa95fb0fa899dab259 Binary files /dev/null and b/geffnet/__pycache__/conv2d_layers.cpython-39.pyc differ diff --git a/geffnet/__pycache__/efficientnet_builder.cpython-39.pyc b/geffnet/__pycache__/efficientnet_builder.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fb2f05bf96c24d271ef5dedef82e0c196ba11db2 Binary files /dev/null and b/geffnet/__pycache__/efficientnet_builder.cpython-39.pyc differ diff --git a/geffnet/__pycache__/gen_efficientnet.cpython-39.pyc b/geffnet/__pycache__/gen_efficientnet.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5fb9c4acdf51df76278212f368161b952df8eaaf Binary files /dev/null and b/geffnet/__pycache__/gen_efficientnet.cpython-39.pyc differ diff --git a/geffnet/__pycache__/helpers.cpython-39.pyc b/geffnet/__pycache__/helpers.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..86913943966e8e7f905b8813b6f0a85e25097bdc Binary files /dev/null and b/geffnet/__pycache__/helpers.cpython-39.pyc differ diff --git a/geffnet/__pycache__/mobilenetv3.cpython-39.pyc b/geffnet/__pycache__/mobilenetv3.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d15b885223a5ca902e6c45e44533efaa77a3f819 Binary files /dev/null and b/geffnet/__pycache__/mobilenetv3.cpython-39.pyc differ diff --git a/geffnet/__pycache__/model_factory.cpython-39.pyc b/geffnet/__pycache__/model_factory.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..918095b79efaa62a6fa26bdb331160fb06a0fc12 Binary files /dev/null and b/geffnet/__pycache__/model_factory.cpython-39.pyc differ diff --git a/geffnet/activations/__init__.py b/geffnet/activations/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..813421a743ffc33b8eb53ebf62dd4a03d831b654 --- /dev/null +++ b/geffnet/activations/__init__.py @@ -0,0 +1,137 @@ +from geffnet import config +from geffnet.activations.activations_me import * +from geffnet.activations.activations_jit import * +from geffnet.activations.activations import * +import torch + +_has_silu = 'silu' in dir(torch.nn.functional) + +_ACT_FN_DEFAULT = dict( + silu=F.silu if _has_silu else swish, + swish=F.silu if _has_silu else swish, + mish=mish, + relu=F.relu, + relu6=F.relu6, + sigmoid=sigmoid, + tanh=tanh, + hard_sigmoid=hard_sigmoid, + hard_swish=hard_swish, +) + +_ACT_FN_JIT = dict( + silu=F.silu if _has_silu else swish_jit, + swish=F.silu if _has_silu else swish_jit, + mish=mish_jit, +) + +_ACT_FN_ME = dict( + silu=F.silu if _has_silu else swish_me, + swish=F.silu if _has_silu else swish_me, + mish=mish_me, + hard_swish=hard_swish_me, + hard_sigmoid_jit=hard_sigmoid_me, +) + +_ACT_LAYER_DEFAULT = dict( + silu=nn.SiLU if _has_silu else Swish, + swish=nn.SiLU if _has_silu else Swish, + mish=Mish, + relu=nn.ReLU, + relu6=nn.ReLU6, + sigmoid=Sigmoid, + tanh=Tanh, + hard_sigmoid=HardSigmoid, + hard_swish=HardSwish, +) + +_ACT_LAYER_JIT = dict( + silu=nn.SiLU if _has_silu else SwishJit, + swish=nn.SiLU if _has_silu else SwishJit, + mish=MishJit, +) + +_ACT_LAYER_ME = dict( + silu=nn.SiLU if _has_silu else SwishMe, + swish=nn.SiLU if _has_silu else SwishMe, + mish=MishMe, + hard_swish=HardSwishMe, + hard_sigmoid=HardSigmoidMe +) + +_OVERRIDE_FN = dict() +_OVERRIDE_LAYER = dict() + + +def add_override_act_fn(name, fn): + global _OVERRIDE_FN + _OVERRIDE_FN[name] = fn + + +def update_override_act_fn(overrides): + assert isinstance(overrides, dict) + global _OVERRIDE_FN + _OVERRIDE_FN.update(overrides) + + +def clear_override_act_fn(): + global _OVERRIDE_FN + _OVERRIDE_FN = dict() + + +def add_override_act_layer(name, fn): + _OVERRIDE_LAYER[name] = fn + + +def update_override_act_layer(overrides): + assert isinstance(overrides, dict) + global _OVERRIDE_LAYER + _OVERRIDE_LAYER.update(overrides) + + +def clear_override_act_layer(): + global _OVERRIDE_LAYER + _OVERRIDE_LAYER = dict() + + +def get_act_fn(name='relu'): + """ Activation Function Factory + Fetching activation fns by name with this function allows export or torch script friendly + functions to be returned dynamically based on current config. + """ + if name in _OVERRIDE_FN: + return _OVERRIDE_FN[name] + use_me = not (config.is_exportable() or config.is_scriptable() or config.is_no_jit()) + if use_me and name in _ACT_FN_ME: + # If not exporting or scripting the model, first look for a memory optimized version + # activation with custom autograd, then fallback to jit scripted, then a Python or Torch builtin + return _ACT_FN_ME[name] + if config.is_exportable() and name in ('silu', 'swish'): + # FIXME PyTorch SiLU doesn't ONNX export, this is a temp hack + return swish + use_jit = not (config.is_exportable() or config.is_no_jit()) + # NOTE: export tracing should work with jit scripted components, but I keep running into issues + if use_jit and name in _ACT_FN_JIT: # jit scripted models should be okay for export/scripting + return _ACT_FN_JIT[name] + return _ACT_FN_DEFAULT[name] + + +def get_act_layer(name='relu'): + """ Activation Layer Factory + Fetching activation layers by name with this function allows export or torch script friendly + functions to be returned dynamically based on current config. + """ + if name in _OVERRIDE_LAYER: + return _OVERRIDE_LAYER[name] + use_me = not (config.is_exportable() or config.is_scriptable() or config.is_no_jit()) + if use_me and name in _ACT_LAYER_ME: + return _ACT_LAYER_ME[name] + if config.is_exportable() and name in ('silu', 'swish'): + # FIXME PyTorch SiLU doesn't ONNX export, this is a temp hack + return Swish + use_jit = not (config.is_exportable() or config.is_no_jit()) + # NOTE: export tracing should work with jit scripted components, but I keep running into issues + if use_jit and name in _ACT_FN_JIT: # jit scripted models should be okay for export/scripting + return _ACT_LAYER_JIT[name] + return _ACT_LAYER_DEFAULT[name] + + diff --git a/geffnet/activations/__pycache__/__init__.cpython-39.pyc b/geffnet/activations/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2eaace74469a026f0fab564776e7941f488fbc6e Binary files /dev/null and b/geffnet/activations/__pycache__/__init__.cpython-39.pyc differ diff --git a/geffnet/activations/__pycache__/activations.cpython-39.pyc b/geffnet/activations/__pycache__/activations.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..219d0a36cac083bdf69078613664ad0dd83ec18c Binary files /dev/null and b/geffnet/activations/__pycache__/activations.cpython-39.pyc differ diff --git a/geffnet/activations/__pycache__/activations_jit.cpython-39.pyc b/geffnet/activations/__pycache__/activations_jit.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..37e726695f717cf66b66662bae145ac3963c935a Binary files /dev/null and b/geffnet/activations/__pycache__/activations_jit.cpython-39.pyc differ diff --git a/geffnet/activations/__pycache__/activations_me.cpython-39.pyc b/geffnet/activations/__pycache__/activations_me.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8c5478271118bbd9efab7140ff3864805dc6b5ff Binary files /dev/null and b/geffnet/activations/__pycache__/activations_me.cpython-39.pyc differ diff --git a/geffnet/activations/activations.py b/geffnet/activations/activations.py new file mode 100644 index 0000000000000000000000000000000000000000..bdea692d1397673b2513d898c33edbcb37d94240 --- /dev/null +++ b/geffnet/activations/activations.py @@ -0,0 +1,102 @@ +""" Activations + +A collection of activations fn and modules with a common interface so that they can +easily be swapped. All have an `inplace` arg even if not used. + +Copyright 2020 Ross Wightman +""" +from torch import nn as nn +from torch.nn import functional as F + + +def swish(x, inplace: bool = False): + """Swish - Described originally as SiLU (https://arxiv.org/abs/1702.03118v3) + and also as Swish (https://arxiv.org/abs/1710.05941). + + TODO Rename to SiLU with addition to PyTorch + """ + return x.mul_(x.sigmoid()) if inplace else x.mul(x.sigmoid()) + + +class Swish(nn.Module): + def __init__(self, inplace: bool = False): + super(Swish, self).__init__() + self.inplace = inplace + + def forward(self, x): + return swish(x, self.inplace) + + +def mish(x, inplace: bool = False): + """Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681 + """ + return x.mul(F.softplus(x).tanh()) + + +class Mish(nn.Module): + def __init__(self, inplace: bool = False): + super(Mish, self).__init__() + self.inplace = inplace + + def forward(self, x): + return mish(x, self.inplace) + + +def sigmoid(x, inplace: bool = False): + return x.sigmoid_() if inplace else x.sigmoid() + + +# PyTorch has this, but not with a consistent inplace argmument interface +class Sigmoid(nn.Module): + def __init__(self, inplace: bool = False): + super(Sigmoid, self).__init__() + self.inplace = inplace + + def forward(self, x): + return x.sigmoid_() if self.inplace else x.sigmoid() + + +def tanh(x, inplace: bool = False): + return x.tanh_() if inplace else x.tanh() + + +# PyTorch has this, but not with a consistent inplace argmument interface +class Tanh(nn.Module): + def __init__(self, inplace: bool = False): + super(Tanh, self).__init__() + self.inplace = inplace + + def forward(self, x): + return x.tanh_() if self.inplace else x.tanh() + + +def hard_swish(x, inplace: bool = False): + inner = F.relu6(x + 3.).div_(6.) + return x.mul_(inner) if inplace else x.mul(inner) + + +class HardSwish(nn.Module): + def __init__(self, inplace: bool = False): + super(HardSwish, self).__init__() + self.inplace = inplace + + def forward(self, x): + return hard_swish(x, self.inplace) + + +def hard_sigmoid(x, inplace: bool = False): + if inplace: + return x.add_(3.).clamp_(0., 6.).div_(6.) + else: + return F.relu6(x + 3.) / 6. + + +class HardSigmoid(nn.Module): + def __init__(self, inplace: bool = False): + super(HardSigmoid, self).__init__() + self.inplace = inplace + + def forward(self, x): + return hard_sigmoid(x, self.inplace) + + diff --git a/geffnet/activations/activations_jit.py b/geffnet/activations/activations_jit.py new file mode 100644 index 0000000000000000000000000000000000000000..7176b05e779787528a47f20d55d64d4a0f219360 --- /dev/null +++ b/geffnet/activations/activations_jit.py @@ -0,0 +1,79 @@ +""" Activations (jit) + +A collection of jit-scripted activations fn and modules with a common interface so that they can +easily be swapped. All have an `inplace` arg even if not used. + +All jit scripted activations are lacking in-place variations on purpose, scripted kernel fusion does not +currently work across in-place op boundaries, thus performance is equal to or less than the non-scripted +versions if they contain in-place ops. + +Copyright 2020 Ross Wightman +""" + +import torch +from torch import nn as nn +from torch.nn import functional as F + +__all__ = ['swish_jit', 'SwishJit', 'mish_jit', 'MishJit', + 'hard_sigmoid_jit', 'HardSigmoidJit', 'hard_swish_jit', 'HardSwishJit'] + + +@torch.jit.script +def swish_jit(x, inplace: bool = False): + """Swish - Described originally as SiLU (https://arxiv.org/abs/1702.03118v3) + and also as Swish (https://arxiv.org/abs/1710.05941). + + TODO Rename to SiLU with addition to PyTorch + """ + return x.mul(x.sigmoid()) + + +@torch.jit.script +def mish_jit(x, _inplace: bool = False): + """Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681 + """ + return x.mul(F.softplus(x).tanh()) + + +class SwishJit(nn.Module): + def __init__(self, inplace: bool = False): + super(SwishJit, self).__init__() + + def forward(self, x): + return swish_jit(x) + + +class MishJit(nn.Module): + def __init__(self, inplace: bool = False): + super(MishJit, self).__init__() + + def forward(self, x): + return mish_jit(x) + + +@torch.jit.script +def hard_sigmoid_jit(x, inplace: bool = False): + # return F.relu6(x + 3.) / 6. + return (x + 3).clamp(min=0, max=6).div(6.) # clamp seems ever so slightly faster? + + +class HardSigmoidJit(nn.Module): + def __init__(self, inplace: bool = False): + super(HardSigmoidJit, self).__init__() + + def forward(self, x): + return hard_sigmoid_jit(x) + + +@torch.jit.script +def hard_swish_jit(x, inplace: bool = False): + # return x * (F.relu6(x + 3.) / 6) + return x * (x + 3).clamp(min=0, max=6).div(6.) # clamp seems ever so slightly faster? + + +class HardSwishJit(nn.Module): + def __init__(self, inplace: bool = False): + super(HardSwishJit, self).__init__() + + def forward(self, x): + return hard_swish_jit(x) diff --git a/geffnet/activations/activations_me.py b/geffnet/activations/activations_me.py new file mode 100644 index 0000000000000000000000000000000000000000..e91df5a50fdbe40bc386e2541a4fda743ad95e9a --- /dev/null +++ b/geffnet/activations/activations_me.py @@ -0,0 +1,174 @@ +""" Activations (memory-efficient w/ custom autograd) + +A collection of activations fn and modules with a common interface so that they can +easily be swapped. All have an `inplace` arg even if not used. + +These activations are not compatible with jit scripting or ONNX export of the model, please use either +the JIT or basic versions of the activations. + +Copyright 2020 Ross Wightman +""" + +import torch +from torch import nn as nn +from torch.nn import functional as F + + +__all__ = ['swish_me', 'SwishMe', 'mish_me', 'MishMe', + 'hard_sigmoid_me', 'HardSigmoidMe', 'hard_swish_me', 'HardSwishMe'] + + +@torch.jit.script +def swish_jit_fwd(x): + return x.mul(torch.sigmoid(x)) + + +@torch.jit.script +def swish_jit_bwd(x, grad_output): + x_sigmoid = torch.sigmoid(x) + return grad_output * (x_sigmoid * (1 + x * (1 - x_sigmoid))) + + +class SwishJitAutoFn(torch.autograd.Function): + """ torch.jit.script optimised Swish w/ memory-efficient checkpoint + Inspired by conversation btw Jeremy Howard & Adam Pazske + https://twitter.com/jeremyphoward/status/1188251041835315200 + + Swish - Described originally as SiLU (https://arxiv.org/abs/1702.03118v3) + and also as Swish (https://arxiv.org/abs/1710.05941). + + TODO Rename to SiLU with addition to PyTorch + """ + + @staticmethod + def forward(ctx, x): + ctx.save_for_backward(x) + return swish_jit_fwd(x) + + @staticmethod + def backward(ctx, grad_output): + x = ctx.saved_tensors[0] + return swish_jit_bwd(x, grad_output) + + +def swish_me(x, inplace=False): + return SwishJitAutoFn.apply(x) + + +class SwishMe(nn.Module): + def __init__(self, inplace: bool = False): + super(SwishMe, self).__init__() + + def forward(self, x): + return SwishJitAutoFn.apply(x) + + +@torch.jit.script +def mish_jit_fwd(x): + return x.mul(torch.tanh(F.softplus(x))) + + +@torch.jit.script +def mish_jit_bwd(x, grad_output): + x_sigmoid = torch.sigmoid(x) + x_tanh_sp = F.softplus(x).tanh() + return grad_output.mul(x_tanh_sp + x * x_sigmoid * (1 - x_tanh_sp * x_tanh_sp)) + + +class MishJitAutoFn(torch.autograd.Function): + """ Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681 + A memory efficient, jit scripted variant of Mish + """ + @staticmethod + def forward(ctx, x): + ctx.save_for_backward(x) + return mish_jit_fwd(x) + + @staticmethod + def backward(ctx, grad_output): + x = ctx.saved_tensors[0] + return mish_jit_bwd(x, grad_output) + + +def mish_me(x, inplace=False): + return MishJitAutoFn.apply(x) + + +class MishMe(nn.Module): + def __init__(self, inplace: bool = False): + super(MishMe, self).__init__() + + def forward(self, x): + return MishJitAutoFn.apply(x) + + +@torch.jit.script +def hard_sigmoid_jit_fwd(x, inplace: bool = False): + return (x + 3).clamp(min=0, max=6).div(6.) + + +@torch.jit.script +def hard_sigmoid_jit_bwd(x, grad_output): + m = torch.ones_like(x) * ((x >= -3.) & (x <= 3.)) / 6. + return grad_output * m + + +class HardSigmoidJitAutoFn(torch.autograd.Function): + @staticmethod + def forward(ctx, x): + ctx.save_for_backward(x) + return hard_sigmoid_jit_fwd(x) + + @staticmethod + def backward(ctx, grad_output): + x = ctx.saved_tensors[0] + return hard_sigmoid_jit_bwd(x, grad_output) + + +def hard_sigmoid_me(x, inplace: bool = False): + return HardSigmoidJitAutoFn.apply(x) + + +class HardSigmoidMe(nn.Module): + def __init__(self, inplace: bool = False): + super(HardSigmoidMe, self).__init__() + + def forward(self, x): + return HardSigmoidJitAutoFn.apply(x) + + +@torch.jit.script +def hard_swish_jit_fwd(x): + return x * (x + 3).clamp(min=0, max=6).div(6.) + + +@torch.jit.script +def hard_swish_jit_bwd(x, grad_output): + m = torch.ones_like(x) * (x >= 3.) + m = torch.where((x >= -3.) & (x <= 3.), x / 3. + .5, m) + return grad_output * m + + +class HardSwishJitAutoFn(torch.autograd.Function): + """A memory efficient, jit-scripted HardSwish activation""" + @staticmethod + def forward(ctx, x): + ctx.save_for_backward(x) + return hard_swish_jit_fwd(x) + + @staticmethod + def backward(ctx, grad_output): + x = ctx.saved_tensors[0] + return hard_swish_jit_bwd(x, grad_output) + + +def hard_swish_me(x, inplace=False): + return HardSwishJitAutoFn.apply(x) + + +class HardSwishMe(nn.Module): + def __init__(self, inplace: bool = False): + super(HardSwishMe, self).__init__() + + def forward(self, x): + return HardSwishJitAutoFn.apply(x) diff --git a/geffnet/config.py b/geffnet/config.py new file mode 100644 index 0000000000000000000000000000000000000000..27d5307fd9ee0246f1e35f41520f17385d23f1dd --- /dev/null +++ b/geffnet/config.py @@ -0,0 +1,123 @@ +""" Global layer config state +""" +from typing import Any, Optional + +__all__ = [ + 'is_exportable', 'is_scriptable', 'is_no_jit', 'layer_config_kwargs', + 'set_exportable', 'set_scriptable', 'set_no_jit', 'set_layer_config' +] + +# Set to True if prefer to have layers with no jit optimization (includes activations) +_NO_JIT = False + +# Set to True if prefer to have activation layers with no jit optimization +# NOTE not currently used as no difference between no_jit and no_activation jit as only layers obeying +# the jit flags so far are activations. This will change as more layers are updated and/or added. +_NO_ACTIVATION_JIT = False + +# Set to True if exporting a model with Same padding via ONNX +_EXPORTABLE = False + +# Set to True if wanting to use torch.jit.script on a model +_SCRIPTABLE = False + + +def is_no_jit(): + return _NO_JIT + + +class set_no_jit: + def __init__(self, mode: bool) -> None: + global _NO_JIT + self.prev = _NO_JIT + _NO_JIT = mode + + def __enter__(self) -> None: + pass + + def __exit__(self, *args: Any) -> bool: + global _NO_JIT + _NO_JIT = self.prev + return False + + +def is_exportable(): + return _EXPORTABLE + + +class set_exportable: + def __init__(self, mode: bool) -> None: + global _EXPORTABLE + self.prev = _EXPORTABLE + _EXPORTABLE = mode + + def __enter__(self) -> None: + pass + + def __exit__(self, *args: Any) -> bool: + global _EXPORTABLE + _EXPORTABLE = self.prev + return False + + +def is_scriptable(): + return _SCRIPTABLE + + +class set_scriptable: + def __init__(self, mode: bool) -> None: + global _SCRIPTABLE + self.prev = _SCRIPTABLE + _SCRIPTABLE = mode + + def __enter__(self) -> None: + pass + + def __exit__(self, *args: Any) -> bool: + global _SCRIPTABLE + _SCRIPTABLE = self.prev + return False + + +class set_layer_config: + """ Layer config context manager that allows setting all layer config flags at once. + If a flag arg is None, it will not change the current value. + """ + def __init__( + self, + scriptable: Optional[bool] = None, + exportable: Optional[bool] = None, + no_jit: Optional[bool] = None, + no_activation_jit: Optional[bool] = None): + global _SCRIPTABLE + global _EXPORTABLE + global _NO_JIT + global _NO_ACTIVATION_JIT + self.prev = _SCRIPTABLE, _EXPORTABLE, _NO_JIT, _NO_ACTIVATION_JIT + if scriptable is not None: + _SCRIPTABLE = scriptable + if exportable is not None: + _EXPORTABLE = exportable + if no_jit is not None: + _NO_JIT = no_jit + if no_activation_jit is not None: + _NO_ACTIVATION_JIT = no_activation_jit + + def __enter__(self) -> None: + pass + + def __exit__(self, *args: Any) -> bool: + global _SCRIPTABLE + global _EXPORTABLE + global _NO_JIT + global _NO_ACTIVATION_JIT + _SCRIPTABLE, _EXPORTABLE, _NO_JIT, _NO_ACTIVATION_JIT = self.prev + return False + + +def layer_config_kwargs(kwargs): + """ Consume config kwargs and return contextmgr obj """ + return set_layer_config( + scriptable=kwargs.pop('scriptable', None), + exportable=kwargs.pop('exportable', None), + no_jit=kwargs.pop('no_jit', None)) diff --git a/geffnet/conv2d_layers.py b/geffnet/conv2d_layers.py new file mode 100644 index 0000000000000000000000000000000000000000..d8467460c4b36e54c83ce2dcd3ebe91d3432cad2 --- /dev/null +++ b/geffnet/conv2d_layers.py @@ -0,0 +1,304 @@ +""" Conv2D w/ SAME padding, CondConv, MixedConv + +A collection of conv layers and padding helpers needed by EfficientNet, MixNet, and +MobileNetV3 models that maintain weight compatibility with original Tensorflow models. + +Copyright 2020 Ross Wightman +""" +import collections.abc +import math +from functools import partial +from itertools import repeat +from typing import Tuple, Optional + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .config import * + + +# From PyTorch internals +def _ntuple(n): + def parse(x): + if isinstance(x, collections.abc.Iterable): + return x + return tuple(repeat(x, n)) + return parse + + +_single = _ntuple(1) +_pair = _ntuple(2) +_triple = _ntuple(3) +_quadruple = _ntuple(4) + + +def _is_static_pad(kernel_size, stride=1, dilation=1, **_): + return stride == 1 and (dilation * (kernel_size - 1)) % 2 == 0 + + +def _get_padding(kernel_size, stride=1, dilation=1, **_): + padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2 + return padding + + +def _calc_same_pad(i: int, k: int, s: int, d: int): + return max((-(i // -s) - 1) * s + (k - 1) * d + 1 - i, 0) + + +def _same_pad_arg(input_size, kernel_size, stride, dilation): + ih, iw = input_size + kh, kw = kernel_size + pad_h = _calc_same_pad(ih, kh, stride[0], dilation[0]) + pad_w = _calc_same_pad(iw, kw, stride[1], dilation[1]) + return [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2] + + +def _split_channels(num_chan, num_groups): + split = [num_chan // num_groups for _ in range(num_groups)] + split[0] += num_chan - sum(split) + return split + + +def conv2d_same( + x, weight: torch.Tensor, bias: Optional[torch.Tensor] = None, stride: Tuple[int, int] = (1, 1), + padding: Tuple[int, int] = (0, 0), dilation: Tuple[int, int] = (1, 1), groups: int = 1): + ih, iw = x.size()[-2:] + kh, kw = weight.size()[-2:] + pad_h = _calc_same_pad(ih, kh, stride[0], dilation[0]) + pad_w = _calc_same_pad(iw, kw, stride[1], dilation[1]) + x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2]) + return F.conv2d(x, weight, bias, stride, (0, 0), dilation, groups) + + +class Conv2dSame(nn.Conv2d): + """ Tensorflow like 'SAME' convolution wrapper for 2D convolutions + """ + + # pylint: disable=unused-argument + def __init__(self, in_channels, out_channels, kernel_size, stride=1, + padding=0, dilation=1, groups=1, bias=True): + super(Conv2dSame, self).__init__( + in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias) + + def forward(self, x): + return conv2d_same(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups) + + +class Conv2dSameExport(nn.Conv2d): + """ ONNX export friendly Tensorflow like 'SAME' convolution wrapper for 2D convolutions + + NOTE: This does not currently work with torch.jit.script + """ + + # pylint: disable=unused-argument + def __init__(self, in_channels, out_channels, kernel_size, stride=1, + padding=0, dilation=1, groups=1, bias=True): + super(Conv2dSameExport, self).__init__( + in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias) + self.pad = None + self.pad_input_size = (0, 0) + + def forward(self, x): + input_size = x.size()[-2:] + if self.pad is None: + pad_arg = _same_pad_arg(input_size, self.weight.size()[-2:], self.stride, self.dilation) + self.pad = nn.ZeroPad2d(pad_arg) + self.pad_input_size = input_size + + if self.pad is not None: + x = self.pad(x) + return F.conv2d( + x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups) + + +def get_padding_value(padding, kernel_size, **kwargs): + dynamic = False + if isinstance(padding, str): + # for any string padding, the padding will be calculated for you, one of three ways + padding = padding.lower() + if padding == 'same': + # TF compatible 'SAME' padding, has a performance and GPU memory allocation impact + if _is_static_pad(kernel_size, **kwargs): + # static case, no extra overhead + padding = _get_padding(kernel_size, **kwargs) + else: + # dynamic padding + padding = 0 + dynamic = True + elif padding == 'valid': + # 'VALID' padding, same as padding=0 + padding = 0 + else: + # Default to PyTorch style 'same'-ish symmetric padding + padding = _get_padding(kernel_size, **kwargs) + return padding, dynamic + + +def create_conv2d_pad(in_chs, out_chs, kernel_size, **kwargs): + padding = kwargs.pop('padding', '') + kwargs.setdefault('bias', False) + padding, is_dynamic = get_padding_value(padding, kernel_size, **kwargs) + if is_dynamic: + if is_exportable(): + assert not is_scriptable() + return Conv2dSameExport(in_chs, out_chs, kernel_size, **kwargs) + else: + return Conv2dSame(in_chs, out_chs, kernel_size, **kwargs) + else: + return nn.Conv2d(in_chs, out_chs, kernel_size, padding=padding, **kwargs) + + +class MixedConv2d(nn.ModuleDict): + """ Mixed Grouped Convolution + Based on MDConv and GroupedConv in MixNet impl: + https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mixnet/custom_layers.py + """ + + def __init__(self, in_channels, out_channels, kernel_size=3, + stride=1, padding='', dilation=1, depthwise=False, **kwargs): + super(MixedConv2d, self).__init__() + + kernel_size = kernel_size if isinstance(kernel_size, list) else [kernel_size] + num_groups = len(kernel_size) + in_splits = _split_channels(in_channels, num_groups) + out_splits = _split_channels(out_channels, num_groups) + self.in_channels = sum(in_splits) + self.out_channels = sum(out_splits) + for idx, (k, in_ch, out_ch) in enumerate(zip(kernel_size, in_splits, out_splits)): + conv_groups = out_ch if depthwise else 1 + self.add_module( + str(idx), + create_conv2d_pad( + in_ch, out_ch, k, stride=stride, + padding=padding, dilation=dilation, groups=conv_groups, **kwargs) + ) + self.splits = in_splits + + def forward(self, x): + x_split = torch.split(x, self.splits, 1) + x_out = [conv(x_split[i]) for i, conv in enumerate(self.values())] + x = torch.cat(x_out, 1) + return x + + +def get_condconv_initializer(initializer, num_experts, expert_shape): + def condconv_initializer(weight): + """CondConv initializer function.""" + num_params = np.prod(expert_shape) + if (len(weight.shape) != 2 or weight.shape[0] != num_experts or + weight.shape[1] != num_params): + raise (ValueError( + 'CondConv variables must have shape [num_experts, num_params]')) + for i in range(num_experts): + initializer(weight[i].view(expert_shape)) + return condconv_initializer + + +class CondConv2d(nn.Module): + """ Conditional Convolution + Inspired by: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/condconv/condconv_layers.py + + Grouped convolution hackery for parallel execution of the per-sample kernel filters inspired by this discussion: + https://github.com/pytorch/pytorch/issues/17983 + """ + __constants__ = ['bias', 'in_channels', 'out_channels', 'dynamic_padding'] + + def __init__(self, in_channels, out_channels, kernel_size=3, + stride=1, padding='', dilation=1, groups=1, bias=False, num_experts=4): + super(CondConv2d, self).__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = _pair(kernel_size) + self.stride = _pair(stride) + padding_val, is_padding_dynamic = get_padding_value( + padding, kernel_size, stride=stride, dilation=dilation) + self.dynamic_padding = is_padding_dynamic # if in forward to work with torchscript + self.padding = _pair(padding_val) + self.dilation = _pair(dilation) + self.groups = groups + self.num_experts = num_experts + + self.weight_shape = (self.out_channels, self.in_channels // self.groups) + self.kernel_size + weight_num_param = 1 + for wd in self.weight_shape: + weight_num_param *= wd + self.weight = torch.nn.Parameter(torch.Tensor(self.num_experts, weight_num_param)) + + if bias: + self.bias_shape = (self.out_channels,) + self.bias = torch.nn.Parameter(torch.Tensor(self.num_experts, self.out_channels)) + else: + self.register_parameter('bias', None) + + self.reset_parameters() + + def reset_parameters(self): + init_weight = get_condconv_initializer( + partial(nn.init.kaiming_uniform_, a=math.sqrt(5)), self.num_experts, self.weight_shape) + init_weight(self.weight) + if self.bias is not None: + fan_in = np.prod(self.weight_shape[1:]) + bound = 1 / math.sqrt(fan_in) + init_bias = get_condconv_initializer( + partial(nn.init.uniform_, a=-bound, b=bound), self.num_experts, self.bias_shape) + init_bias(self.bias) + + def forward(self, x, routing_weights): + B, C, H, W = x.shape + weight = torch.matmul(routing_weights, self.weight) + new_weight_shape = (B * self.out_channels, self.in_channels // self.groups) + self.kernel_size + weight = weight.view(new_weight_shape) + bias = None + if self.bias is not None: + bias = torch.matmul(routing_weights, self.bias) + bias = bias.view(B * self.out_channels) + # move batch elements with channels so each batch element can be efficiently convolved with separate kernel + x = x.view(1, B * C, H, W) + if self.dynamic_padding: + out = conv2d_same( + x, weight, bias, stride=self.stride, padding=self.padding, + dilation=self.dilation, groups=self.groups * B) + else: + out = F.conv2d( + x, weight, bias, stride=self.stride, padding=self.padding, + dilation=self.dilation, groups=self.groups * B) + out = out.permute([1, 0, 2, 3]).view(B, self.out_channels, out.shape[-2], out.shape[-1]) + + # Literal port (from TF definition) + # x = torch.split(x, 1, 0) + # weight = torch.split(weight, 1, 0) + # if self.bias is not None: + # bias = torch.matmul(routing_weights, self.bias) + # bias = torch.split(bias, 1, 0) + # else: + # bias = [None] * B + # out = [] + # for xi, wi, bi in zip(x, weight, bias): + # wi = wi.view(*self.weight_shape) + # if bi is not None: + # bi = bi.view(*self.bias_shape) + # out.append(self.conv_fn( + # xi, wi, bi, stride=self.stride, padding=self.padding, + # dilation=self.dilation, groups=self.groups)) + # out = torch.cat(out, 0) + return out + + +def select_conv2d(in_chs, out_chs, kernel_size, **kwargs): + assert 'groups' not in kwargs # only use 'depthwise' bool arg + if isinstance(kernel_size, list): + assert 'num_experts' not in kwargs # MixNet + CondConv combo not supported currently + # We're going to use only lists for defining the MixedConv2d kernel groups, + # ints, tuples, other iterables will continue to pass to normal conv and specify h, w. + m = MixedConv2d(in_chs, out_chs, kernel_size, **kwargs) + else: + depthwise = kwargs.pop('depthwise', False) + groups = out_chs if depthwise else 1 + if 'num_experts' in kwargs and kwargs['num_experts'] > 0: + m = CondConv2d(in_chs, out_chs, kernel_size, groups=groups, **kwargs) + else: + m = create_conv2d_pad(in_chs, out_chs, kernel_size, groups=groups, **kwargs) + return m diff --git a/geffnet/efficientnet_builder.py b/geffnet/efficientnet_builder.py new file mode 100644 index 0000000000000000000000000000000000000000..95dd63d400e70d70664c5a433a2772363f865e61 --- /dev/null +++ b/geffnet/efficientnet_builder.py @@ -0,0 +1,683 @@ +""" EfficientNet / MobileNetV3 Blocks and Builder + +Copyright 2020 Ross Wightman +""" +import re +from copy import deepcopy + +from .conv2d_layers import * +from geffnet.activations import * + +__all__ = ['get_bn_args_tf', 'resolve_bn_args', 'resolve_se_args', 'resolve_act_layer', 'make_divisible', + 'round_channels', 'drop_connect', 'SqueezeExcite', 'ConvBnAct', 'DepthwiseSeparableConv', + 'InvertedResidual', 'CondConvResidual', 'EdgeResidual', 'EfficientNetBuilder', 'decode_arch_def', + 'initialize_weight_default', 'initialize_weight_goog', 'BN_MOMENTUM_TF_DEFAULT', 'BN_EPS_TF_DEFAULT' +] + +# Defaults used for Google/Tensorflow training of mobile networks /w RMSprop as per +# papers and TF reference implementations. PT momentum equiv for TF decay is (1 - TF decay) +# NOTE: momentum varies btw .99 and .9997 depending on source +# .99 in official TF TPU impl +# .9997 (/w .999 in search space) for paper +# +# PyTorch defaults are momentum = .1, eps = 1e-5 +# +BN_MOMENTUM_TF_DEFAULT = 1 - 0.99 +BN_EPS_TF_DEFAULT = 1e-3 +_BN_ARGS_TF = dict(momentum=BN_MOMENTUM_TF_DEFAULT, eps=BN_EPS_TF_DEFAULT) + + +def get_bn_args_tf(): + return _BN_ARGS_TF.copy() + + +def resolve_bn_args(kwargs): + bn_args = get_bn_args_tf() if kwargs.pop('bn_tf', False) else {} + bn_momentum = kwargs.pop('bn_momentum', None) + if bn_momentum is not None: + bn_args['momentum'] = bn_momentum + bn_eps = kwargs.pop('bn_eps', None) + if bn_eps is not None: + bn_args['eps'] = bn_eps + return bn_args + + +_SE_ARGS_DEFAULT = dict( + gate_fn=sigmoid, + act_layer=None, # None == use containing block's activation layer + reduce_mid=False, + divisor=1) + + +def resolve_se_args(kwargs, in_chs, act_layer=None): + se_kwargs = kwargs.copy() if kwargs is not None else {} + # fill in args that aren't specified with the defaults + for k, v in _SE_ARGS_DEFAULT.items(): + se_kwargs.setdefault(k, v) + # some models, like MobilNetV3, calculate SE reduction chs from the containing block's mid_ch instead of in_ch + if not se_kwargs.pop('reduce_mid'): + se_kwargs['reduced_base_chs'] = in_chs + # act_layer override, if it remains None, the containing block's act_layer will be used + if se_kwargs['act_layer'] is None: + assert act_layer is not None + se_kwargs['act_layer'] = act_layer + return se_kwargs + + +def resolve_act_layer(kwargs, default='relu'): + act_layer = kwargs.pop('act_layer', default) + if isinstance(act_layer, str): + act_layer = get_act_layer(act_layer) + return act_layer + + +def make_divisible(v: int, divisor: int = 8, min_value: int = None): + min_value = min_value or divisor + new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) + if new_v < 0.9 * v: # ensure round down does not go down by more than 10%. + new_v += divisor + return new_v + + +def round_channels(channels, multiplier=1.0, divisor=8, channel_min=None): + """Round number of filters based on depth multiplier.""" + if not multiplier: + return channels + channels *= multiplier + return make_divisible(channels, divisor, channel_min) + + +def drop_connect(inputs, training: bool = False, drop_connect_rate: float = 0.): + """Apply drop connect.""" + if not training: + return inputs + + keep_prob = 1 - drop_connect_rate + random_tensor = keep_prob + torch.rand( + (inputs.size()[0], 1, 1, 1), dtype=inputs.dtype, device=inputs.device) + random_tensor.floor_() # binarize + output = inputs.div(keep_prob) * random_tensor + return output + + +class SqueezeExcite(nn.Module): + + def __init__(self, in_chs, se_ratio=0.25, reduced_base_chs=None, act_layer=nn.ReLU, gate_fn=sigmoid, divisor=1): + super(SqueezeExcite, self).__init__() + reduced_chs = make_divisible((reduced_base_chs or in_chs) * se_ratio, divisor) + self.conv_reduce = nn.Conv2d(in_chs, reduced_chs, 1, bias=True) + self.act1 = act_layer(inplace=True) + self.conv_expand = nn.Conv2d(reduced_chs, in_chs, 1, bias=True) + self.gate_fn = gate_fn + + def forward(self, x): + x_se = x.mean((2, 3), keepdim=True) + x_se = self.conv_reduce(x_se) + x_se = self.act1(x_se) + x_se = self.conv_expand(x_se) + x = x * self.gate_fn(x_se) + return x + + +class ConvBnAct(nn.Module): + def __init__(self, in_chs, out_chs, kernel_size, + stride=1, pad_type='', act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, norm_kwargs=None): + super(ConvBnAct, self).__init__() + assert stride in [1, 2] + norm_kwargs = norm_kwargs or {} + self.conv = select_conv2d(in_chs, out_chs, kernel_size, stride=stride, padding=pad_type) + self.bn1 = norm_layer(out_chs, **norm_kwargs) + self.act1 = act_layer(inplace=True) + + def forward(self, x): + x = self.conv(x) + x = self.bn1(x) + x = self.act1(x) + return x + + +class DepthwiseSeparableConv(nn.Module): + """ DepthwiseSeparable block + Used for DS convs in MobileNet-V1 and in the place of IR blocks with an expansion + factor of 1.0. This is an alternative to having a IR with optional first pw conv. + """ + def __init__(self, in_chs, out_chs, dw_kernel_size=3, + stride=1, pad_type='', act_layer=nn.ReLU, noskip=False, + pw_kernel_size=1, pw_act=False, se_ratio=0., se_kwargs=None, + norm_layer=nn.BatchNorm2d, norm_kwargs=None, drop_connect_rate=0.): + super(DepthwiseSeparableConv, self).__init__() + assert stride in [1, 2] + norm_kwargs = norm_kwargs or {} + self.has_residual = (stride == 1 and in_chs == out_chs) and not noskip + self.drop_connect_rate = drop_connect_rate + + self.conv_dw = select_conv2d( + in_chs, in_chs, dw_kernel_size, stride=stride, padding=pad_type, depthwise=True) + self.bn1 = norm_layer(in_chs, **norm_kwargs) + self.act1 = act_layer(inplace=True) + + # Squeeze-and-excitation + if se_ratio is not None and se_ratio > 0.: + se_kwargs = resolve_se_args(se_kwargs, in_chs, act_layer) + self.se = SqueezeExcite(in_chs, se_ratio=se_ratio, **se_kwargs) + else: + self.se = nn.Identity() + + self.conv_pw = select_conv2d(in_chs, out_chs, pw_kernel_size, padding=pad_type) + self.bn2 = norm_layer(out_chs, **norm_kwargs) + self.act2 = act_layer(inplace=True) if pw_act else nn.Identity() + + def forward(self, x): + residual = x + + x = self.conv_dw(x) + x = self.bn1(x) + x = self.act1(x) + + x = self.se(x) + + x = self.conv_pw(x) + x = self.bn2(x) + x = self.act2(x) + + if self.has_residual: + if self.drop_connect_rate > 0.: + x = drop_connect(x, self.training, self.drop_connect_rate) + x += residual + return x + + +class InvertedResidual(nn.Module): + """ Inverted residual block w/ optional SE""" + + def __init__(self, in_chs, out_chs, dw_kernel_size=3, + stride=1, pad_type='', act_layer=nn.ReLU, noskip=False, + exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1, + se_ratio=0., se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None, + conv_kwargs=None, drop_connect_rate=0.): + super(InvertedResidual, self).__init__() + norm_kwargs = norm_kwargs or {} + conv_kwargs = conv_kwargs or {} + mid_chs: int = make_divisible(in_chs * exp_ratio) + self.has_residual = (in_chs == out_chs and stride == 1) and not noskip + self.drop_connect_rate = drop_connect_rate + + # Point-wise expansion + self.conv_pw = select_conv2d(in_chs, mid_chs, exp_kernel_size, padding=pad_type, **conv_kwargs) + self.bn1 = norm_layer(mid_chs, **norm_kwargs) + self.act1 = act_layer(inplace=True) + + # Depth-wise convolution + self.conv_dw = select_conv2d( + mid_chs, mid_chs, dw_kernel_size, stride=stride, padding=pad_type, depthwise=True, **conv_kwargs) + self.bn2 = norm_layer(mid_chs, **norm_kwargs) + self.act2 = act_layer(inplace=True) + + # Squeeze-and-excitation + if se_ratio is not None and se_ratio > 0.: + se_kwargs = resolve_se_args(se_kwargs, in_chs, act_layer) + self.se = SqueezeExcite(mid_chs, se_ratio=se_ratio, **se_kwargs) + else: + self.se = nn.Identity() # for jit.script compat + + # Point-wise linear projection + self.conv_pwl = select_conv2d(mid_chs, out_chs, pw_kernel_size, padding=pad_type, **conv_kwargs) + self.bn3 = norm_layer(out_chs, **norm_kwargs) + + def forward(self, x): + residual = x + + # Point-wise expansion + x = self.conv_pw(x) + x = self.bn1(x) + x = self.act1(x) + + # Depth-wise convolution + x = self.conv_dw(x) + x = self.bn2(x) + x = self.act2(x) + + # Squeeze-and-excitation + x = self.se(x) + + # Point-wise linear projection + x = self.conv_pwl(x) + x = self.bn3(x) + + if self.has_residual: + if self.drop_connect_rate > 0.: + x = drop_connect(x, self.training, self.drop_connect_rate) + x += residual + return x + + +class CondConvResidual(InvertedResidual): + """ Inverted residual block w/ CondConv routing""" + + def __init__(self, in_chs, out_chs, dw_kernel_size=3, + stride=1, pad_type='', act_layer=nn.ReLU, noskip=False, + exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1, + se_ratio=0., se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None, + num_experts=0, drop_connect_rate=0.): + + self.num_experts = num_experts + conv_kwargs = dict(num_experts=self.num_experts) + + super(CondConvResidual, self).__init__( + in_chs, out_chs, dw_kernel_size=dw_kernel_size, stride=stride, pad_type=pad_type, + act_layer=act_layer, noskip=noskip, exp_ratio=exp_ratio, exp_kernel_size=exp_kernel_size, + pw_kernel_size=pw_kernel_size, se_ratio=se_ratio, se_kwargs=se_kwargs, + norm_layer=norm_layer, norm_kwargs=norm_kwargs, conv_kwargs=conv_kwargs, + drop_connect_rate=drop_connect_rate) + + self.routing_fn = nn.Linear(in_chs, self.num_experts) + + def forward(self, x): + residual = x + + # CondConv routing + pooled_inputs = F.adaptive_avg_pool2d(x, 1).flatten(1) + routing_weights = torch.sigmoid(self.routing_fn(pooled_inputs)) + + # Point-wise expansion + x = self.conv_pw(x, routing_weights) + x = self.bn1(x) + x = self.act1(x) + + # Depth-wise convolution + x = self.conv_dw(x, routing_weights) + x = self.bn2(x) + x = self.act2(x) + + # Squeeze-and-excitation + x = self.se(x) + + # Point-wise linear projection + x = self.conv_pwl(x, routing_weights) + x = self.bn3(x) + + if self.has_residual: + if self.drop_connect_rate > 0.: + x = drop_connect(x, self.training, self.drop_connect_rate) + x += residual + return x + + +class EdgeResidual(nn.Module): + """ EdgeTPU Residual block with expansion convolution followed by pointwise-linear w/ stride""" + + def __init__(self, in_chs, out_chs, exp_kernel_size=3, exp_ratio=1.0, fake_in_chs=0, + stride=1, pad_type='', act_layer=nn.ReLU, noskip=False, pw_kernel_size=1, + se_ratio=0., se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None, drop_connect_rate=0.): + super(EdgeResidual, self).__init__() + norm_kwargs = norm_kwargs or {} + mid_chs = make_divisible(fake_in_chs * exp_ratio) if fake_in_chs > 0 else make_divisible(in_chs * exp_ratio) + self.has_residual = (in_chs == out_chs and stride == 1) and not noskip + self.drop_connect_rate = drop_connect_rate + + # Expansion convolution + self.conv_exp = select_conv2d(in_chs, mid_chs, exp_kernel_size, padding=pad_type) + self.bn1 = norm_layer(mid_chs, **norm_kwargs) + self.act1 = act_layer(inplace=True) + + # Squeeze-and-excitation + if se_ratio is not None and se_ratio > 0.: + se_kwargs = resolve_se_args(se_kwargs, in_chs, act_layer) + self.se = SqueezeExcite(mid_chs, se_ratio=se_ratio, **se_kwargs) + else: + self.se = nn.Identity() + + # Point-wise linear projection + self.conv_pwl = select_conv2d(mid_chs, out_chs, pw_kernel_size, stride=stride, padding=pad_type) + self.bn2 = nn.BatchNorm2d(out_chs, **norm_kwargs) + + def forward(self, x): + residual = x + + # Expansion convolution + x = self.conv_exp(x) + x = self.bn1(x) + x = self.act1(x) + + # Squeeze-and-excitation + x = self.se(x) + + # Point-wise linear projection + x = self.conv_pwl(x) + x = self.bn2(x) + + if self.has_residual: + if self.drop_connect_rate > 0.: + x = drop_connect(x, self.training, self.drop_connect_rate) + x += residual + + return x + + +class EfficientNetBuilder: + """ Build Trunk Blocks for Efficient/Mobile Networks + + This ended up being somewhat of a cross between + https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mnasnet_models.py + and + https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/modeling/backbone/fbnet_builder.py + + """ + + def __init__(self, channel_multiplier=1.0, channel_divisor=8, channel_min=None, + pad_type='', act_layer=None, se_kwargs=None, + norm_layer=nn.BatchNorm2d, norm_kwargs=None, drop_connect_rate=0.): + self.channel_multiplier = channel_multiplier + self.channel_divisor = channel_divisor + self.channel_min = channel_min + self.pad_type = pad_type + self.act_layer = act_layer + self.se_kwargs = se_kwargs + self.norm_layer = norm_layer + self.norm_kwargs = norm_kwargs + self.drop_connect_rate = drop_connect_rate + + # updated during build + self.in_chs = None + self.block_idx = 0 + self.block_count = 0 + + def _round_channels(self, chs): + return round_channels(chs, self.channel_multiplier, self.channel_divisor, self.channel_min) + + def _make_block(self, ba): + bt = ba.pop('block_type') + ba['in_chs'] = self.in_chs + ba['out_chs'] = self._round_channels(ba['out_chs']) + if 'fake_in_chs' in ba and ba['fake_in_chs']: + # FIXME this is a hack to work around mismatch in origin impl input filters for EdgeTPU + ba['fake_in_chs'] = self._round_channels(ba['fake_in_chs']) + ba['norm_layer'] = self.norm_layer + ba['norm_kwargs'] = self.norm_kwargs + ba['pad_type'] = self.pad_type + # block act fn overrides the model default + ba['act_layer'] = ba['act_layer'] if ba['act_layer'] is not None else self.act_layer + assert ba['act_layer'] is not None + if bt == 'ir': + ba['drop_connect_rate'] = self.drop_connect_rate * self.block_idx / self.block_count + ba['se_kwargs'] = self.se_kwargs + if ba.get('num_experts', 0) > 0: + block = CondConvResidual(**ba) + else: + block = InvertedResidual(**ba) + elif bt == 'ds' or bt == 'dsa': + ba['drop_connect_rate'] = self.drop_connect_rate * self.block_idx / self.block_count + ba['se_kwargs'] = self.se_kwargs + block = DepthwiseSeparableConv(**ba) + elif bt == 'er': + ba['drop_connect_rate'] = self.drop_connect_rate * self.block_idx / self.block_count + ba['se_kwargs'] = self.se_kwargs + block = EdgeResidual(**ba) + elif bt == 'cn': + block = ConvBnAct(**ba) + else: + assert False, 'Uknkown block type (%s) while building model.' % bt + self.in_chs = ba['out_chs'] # update in_chs for arg of next block + return block + + def _make_stack(self, stack_args): + blocks = [] + # each stack (stage) contains a list of block arguments + for i, ba in enumerate(stack_args): + if i >= 1: + # only the first block in any stack can have a stride > 1 + ba['stride'] = 1 + block = self._make_block(ba) + blocks.append(block) + self.block_idx += 1 # incr global idx (across all stacks) + return nn.Sequential(*blocks) + + def __call__(self, in_chs, block_args): + """ Build the blocks + Args: + in_chs: Number of input-channels passed to first block + block_args: A list of lists, outer list defines stages, inner + list contains strings defining block configuration(s) + Return: + List of block stacks (each stack wrapped in nn.Sequential) + """ + self.in_chs = in_chs + self.block_count = sum([len(x) for x in block_args]) + self.block_idx = 0 + blocks = [] + # outer list of block_args defines the stacks ('stages' by some conventions) + for stack_idx, stack in enumerate(block_args): + assert isinstance(stack, list) + stack = self._make_stack(stack) + blocks.append(stack) + return blocks + + +def _parse_ksize(ss): + if ss.isdigit(): + return int(ss) + else: + return [int(k) for k in ss.split('.')] + + +def _decode_block_str(block_str): + """ Decode block definition string + + Gets a list of block arg (dicts) through a string notation of arguments. + E.g. ir_r2_k3_s2_e1_i32_o16_se0.25_noskip + + All args can exist in any order with the exception of the leading string which + is assumed to indicate the block type. + + leading string - block type ( + ir = InvertedResidual, ds = DepthwiseSep, dsa = DeptwhiseSep with pw act, cn = ConvBnAct) + r - number of repeat blocks, + k - kernel size, + s - strides (1-9), + e - expansion ratio, + c - output channels, + se - squeeze/excitation ratio + n - activation fn ('re', 'r6', 'hs', or 'sw') + Args: + block_str: a string representation of block arguments. + Returns: + A list of block args (dicts) + Raises: + ValueError: if the string def not properly specified (TODO) + """ + assert isinstance(block_str, str) + ops = block_str.split('_') + block_type = ops[0] # take the block type off the front + ops = ops[1:] + options = {} + noskip = False + for op in ops: + # string options being checked on individual basis, combine if they grow + if op == 'noskip': + noskip = True + elif op.startswith('n'): + # activation fn + key = op[0] + v = op[1:] + if v == 're': + value = get_act_layer('relu') + elif v == 'r6': + value = get_act_layer('relu6') + elif v == 'hs': + value = get_act_layer('hard_swish') + elif v == 'sw': + value = get_act_layer('swish') + else: + continue + options[key] = value + else: + # all numeric options + splits = re.split(r'(\d.*)', op) + if len(splits) >= 2: + key, value = splits[:2] + options[key] = value + + # if act_layer is None, the model default (passed to model init) will be used + act_layer = options['n'] if 'n' in options else None + exp_kernel_size = _parse_ksize(options['a']) if 'a' in options else 1 + pw_kernel_size = _parse_ksize(options['p']) if 'p' in options else 1 + fake_in_chs = int(options['fc']) if 'fc' in options else 0 # FIXME hack to deal with in_chs issue in TPU def + + num_repeat = int(options['r']) + # each type of block has different valid arguments, fill accordingly + if block_type == 'ir': + block_args = dict( + block_type=block_type, + dw_kernel_size=_parse_ksize(options['k']), + exp_kernel_size=exp_kernel_size, + pw_kernel_size=pw_kernel_size, + out_chs=int(options['c']), + exp_ratio=float(options['e']), + se_ratio=float(options['se']) if 'se' in options else None, + stride=int(options['s']), + act_layer=act_layer, + noskip=noskip, + ) + if 'cc' in options: + block_args['num_experts'] = int(options['cc']) + elif block_type == 'ds' or block_type == 'dsa': + block_args = dict( + block_type=block_type, + dw_kernel_size=_parse_ksize(options['k']), + pw_kernel_size=pw_kernel_size, + out_chs=int(options['c']), + se_ratio=float(options['se']) if 'se' in options else None, + stride=int(options['s']), + act_layer=act_layer, + pw_act=block_type == 'dsa', + noskip=block_type == 'dsa' or noskip, + ) + elif block_type == 'er': + block_args = dict( + block_type=block_type, + exp_kernel_size=_parse_ksize(options['k']), + pw_kernel_size=pw_kernel_size, + out_chs=int(options['c']), + exp_ratio=float(options['e']), + fake_in_chs=fake_in_chs, + se_ratio=float(options['se']) if 'se' in options else None, + stride=int(options['s']), + act_layer=act_layer, + noskip=noskip, + ) + elif block_type == 'cn': + block_args = dict( + block_type=block_type, + kernel_size=int(options['k']), + out_chs=int(options['c']), + stride=int(options['s']), + act_layer=act_layer, + ) + else: + assert False, 'Unknown block type (%s)' % block_type + + return block_args, num_repeat + + +def _scale_stage_depth(stack_args, repeats, depth_multiplier=1.0, depth_trunc='ceil'): + """ Per-stage depth scaling + Scales the block repeats in each stage. This depth scaling impl maintains + compatibility with the EfficientNet scaling method, while allowing sensible + scaling for other models that may have multiple block arg definitions in each stage. + """ + + # We scale the total repeat count for each stage, there may be multiple + # block arg defs per stage so we need to sum. + num_repeat = sum(repeats) + if depth_trunc == 'round': + # Truncating to int by rounding allows stages with few repeats to remain + # proportionally smaller for longer. This is a good choice when stage definitions + # include single repeat stages that we'd prefer to keep that way as long as possible + num_repeat_scaled = max(1, round(num_repeat * depth_multiplier)) + else: + # The default for EfficientNet truncates repeats to int via 'ceil'. + # Any multiplier > 1.0 will result in an increased depth for every stage. + num_repeat_scaled = int(math.ceil(num_repeat * depth_multiplier)) + + # Proportionally distribute repeat count scaling to each block definition in the stage. + # Allocation is done in reverse as it results in the first block being less likely to be scaled. + # The first block makes less sense to repeat in most of the arch definitions. + repeats_scaled = [] + for r in repeats[::-1]: + rs = max(1, round((r / num_repeat * num_repeat_scaled))) + repeats_scaled.append(rs) + num_repeat -= r + num_repeat_scaled -= rs + repeats_scaled = repeats_scaled[::-1] + + # Apply the calculated scaling to each block arg in the stage + sa_scaled = [] + for ba, rep in zip(stack_args, repeats_scaled): + sa_scaled.extend([deepcopy(ba) for _ in range(rep)]) + return sa_scaled + + +def decode_arch_def(arch_def, depth_multiplier=1.0, depth_trunc='ceil', experts_multiplier=1, fix_first_last=False): + arch_args = [] + for stack_idx, block_strings in enumerate(arch_def): + assert isinstance(block_strings, list) + stack_args = [] + repeats = [] + for block_str in block_strings: + assert isinstance(block_str, str) + ba, rep = _decode_block_str(block_str) + if ba.get('num_experts', 0) > 0 and experts_multiplier > 1: + ba['num_experts'] *= experts_multiplier + stack_args.append(ba) + repeats.append(rep) + if fix_first_last and (stack_idx == 0 or stack_idx == len(arch_def) - 1): + arch_args.append(_scale_stage_depth(stack_args, repeats, 1.0, depth_trunc)) + else: + arch_args.append(_scale_stage_depth(stack_args, repeats, depth_multiplier, depth_trunc)) + return arch_args + + +def initialize_weight_goog(m, n='', fix_group_fanout=True): + # weight init as per Tensorflow Official impl + # https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mnasnet_model.py + if isinstance(m, CondConv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + if fix_group_fanout: + fan_out //= m.groups + init_weight_fn = get_condconv_initializer( + lambda w: w.data.normal_(0, math.sqrt(2.0 / fan_out)), m.num_experts, m.weight_shape) + init_weight_fn(m.weight) + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + if fix_group_fanout: + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1.0) + m.bias.data.zero_() + elif isinstance(m, nn.Linear): + fan_out = m.weight.size(0) # fan-out + fan_in = 0 + if 'routing_fn' in n: + fan_in = m.weight.size(1) + init_range = 1.0 / math.sqrt(fan_in + fan_out) + m.weight.data.uniform_(-init_range, init_range) + m.bias.data.zero_() + + +def initialize_weight_default(m, n=''): + if isinstance(m, CondConv2d): + init_fn = get_condconv_initializer(partial( + nn.init.kaiming_normal_, mode='fan_out', nonlinearity='relu'), m.num_experts, m.weight_shape) + init_fn(m.weight) + elif isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1.0) + m.bias.data.zero_() + elif isinstance(m, nn.Linear): + nn.init.kaiming_uniform_(m.weight, mode='fan_in', nonlinearity='linear') diff --git a/geffnet/gen_efficientnet.py b/geffnet/gen_efficientnet.py new file mode 100644 index 0000000000000000000000000000000000000000..cd170d4cc5bed6ca82b61539902b470d3320c691 --- /dev/null +++ b/geffnet/gen_efficientnet.py @@ -0,0 +1,1450 @@ +""" Generic Efficient Networks + +A generic MobileNet class with building blocks to support a variety of models: + +* EfficientNet (B0-B8, L2 + Tensorflow pretrained AutoAug/RandAug/AdvProp/NoisyStudent ports) + - EfficientNet: Rethinking Model Scaling for CNNs - https://arxiv.org/abs/1905.11946 + - CondConv: Conditionally Parameterized Convolutions for Efficient Inference - https://arxiv.org/abs/1904.04971 + - Adversarial Examples Improve Image Recognition - https://arxiv.org/abs/1911.09665 + - Self-training with Noisy Student improves ImageNet classification - https://arxiv.org/abs/1911.04252 + +* EfficientNet-Lite + +* MixNet (Small, Medium, and Large) + - MixConv: Mixed Depthwise Convolutional Kernels - https://arxiv.org/abs/1907.09595 + +* MNasNet B1, A1 (SE), Small + - MnasNet: Platform-Aware Neural Architecture Search for Mobile - https://arxiv.org/abs/1807.11626 + +* FBNet-C + - FBNet: Hardware-Aware Efficient ConvNet Design via Differentiable NAS - https://arxiv.org/abs/1812.03443 + +* Single-Path NAS Pixel1 + - Single-Path NAS: Designing Hardware-Efficient ConvNets - https://arxiv.org/abs/1904.02877 + +* And likely more... + +Hacked together by / Copyright 2020 Ross Wightman +""" +import torch.nn as nn +import torch.nn.functional as F + +from .config import layer_config_kwargs, is_scriptable +from .conv2d_layers import select_conv2d +from .helpers import load_pretrained +from .efficientnet_builder import * + +__all__ = ['GenEfficientNet', 'mnasnet_050', 'mnasnet_075', 'mnasnet_100', 'mnasnet_b1', 'mnasnet_140', + 'semnasnet_050', 'semnasnet_075', 'semnasnet_100', 'mnasnet_a1', 'semnasnet_140', 'mnasnet_small', + 'mobilenetv2_100', 'mobilenetv2_140', 'mobilenetv2_110d', 'mobilenetv2_120d', + 'fbnetc_100', 'spnasnet_100', 'efficientnet_b0', 'efficientnet_b1', 'efficientnet_b2', 'efficientnet_b3', + 'efficientnet_b4', 'efficientnet_b5', 'efficientnet_b6', 'efficientnet_b7', 'efficientnet_b8', + 'efficientnet_l2', 'efficientnet_es', 'efficientnet_em', 'efficientnet_el', + 'efficientnet_cc_b0_4e', 'efficientnet_cc_b0_8e', 'efficientnet_cc_b1_8e', + 'efficientnet_lite0', 'efficientnet_lite1', 'efficientnet_lite2', 'efficientnet_lite3', 'efficientnet_lite4', + 'tf_efficientnet_b0', 'tf_efficientnet_b1', 'tf_efficientnet_b2', 'tf_efficientnet_b3', + 'tf_efficientnet_b4', 'tf_efficientnet_b5', 'tf_efficientnet_b6', 'tf_efficientnet_b7', 'tf_efficientnet_b8', + 'tf_efficientnet_b0_ap', 'tf_efficientnet_b1_ap', 'tf_efficientnet_b2_ap', 'tf_efficientnet_b3_ap', + 'tf_efficientnet_b4_ap', 'tf_efficientnet_b5_ap', 'tf_efficientnet_b6_ap', 'tf_efficientnet_b7_ap', + 'tf_efficientnet_b8_ap', 'tf_efficientnet_b0_ns', 'tf_efficientnet_b1_ns', 'tf_efficientnet_b2_ns', + 'tf_efficientnet_b3_ns', 'tf_efficientnet_b4_ns', 'tf_efficientnet_b5_ns', 'tf_efficientnet_b6_ns', + 'tf_efficientnet_b7_ns', 'tf_efficientnet_l2_ns', 'tf_efficientnet_l2_ns_475', + 'tf_efficientnet_es', 'tf_efficientnet_em', 'tf_efficientnet_el', + 'tf_efficientnet_cc_b0_4e', 'tf_efficientnet_cc_b0_8e', 'tf_efficientnet_cc_b1_8e', + 'tf_efficientnet_lite0', 'tf_efficientnet_lite1', 'tf_efficientnet_lite2', 'tf_efficientnet_lite3', + 'tf_efficientnet_lite4', + 'mixnet_s', 'mixnet_m', 'mixnet_l', 'mixnet_xl', 'tf_mixnet_s', 'tf_mixnet_m', 'tf_mixnet_l'] + + +model_urls = { + 'mnasnet_050': None, + 'mnasnet_075': None, + 'mnasnet_100': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mnasnet_b1-74cb7081.pth', + 'mnasnet_140': None, + 'mnasnet_small': None, + + 'semnasnet_050': None, + 'semnasnet_075': None, + 'semnasnet_100': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mnasnet_a1-d9418771.pth', + 'semnasnet_140': None, + + 'mobilenetv2_100': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv2_100_ra-b33bc2c4.pth', + 'mobilenetv2_110d': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv2_110d_ra-77090ade.pth', + 'mobilenetv2_120d': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv2_120d_ra-5987e2ed.pth', + 'mobilenetv2_140': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv2_140_ra-21a4e913.pth', + + 'fbnetc_100': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/fbnetc_100-c345b898.pth', + 'spnasnet_100': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/spnasnet_100-048bc3f4.pth', + + 'efficientnet_b0': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b0_ra-3dd342df.pth', + 'efficientnet_b1': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b1-533bc792.pth', + 'efficientnet_b2': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b2_ra-bcdf34b7.pth', + 'efficientnet_b3': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b3_ra2-cf984f9c.pth', + 'efficientnet_b4': None, + 'efficientnet_b5': None, + 'efficientnet_b6': None, + 'efficientnet_b7': None, + 'efficientnet_b8': None, + 'efficientnet_l2': None, + + 'efficientnet_es': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_es_ra-f111e99c.pth', + 'efficientnet_em': None, + 'efficientnet_el': None, + + 'efficientnet_cc_b0_4e': None, + 'efficientnet_cc_b0_8e': None, + 'efficientnet_cc_b1_8e': None, + + 'efficientnet_lite0': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_lite0_ra-37913777.pth', + 'efficientnet_lite1': None, + 'efficientnet_lite2': None, + 'efficientnet_lite3': None, + 'efficientnet_lite4': None, + + 'tf_efficientnet_b0': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0_aa-827b6e33.pth', + 'tf_efficientnet_b1': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b1_aa-ea7a6ee0.pth', + 'tf_efficientnet_b2': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b2_aa-60c94f97.pth', + 'tf_efficientnet_b3': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b3_aa-84b4657e.pth', + 'tf_efficientnet_b4': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b4_aa-818f208c.pth', + 'tf_efficientnet_b5': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b5_ra-9a3e5369.pth', + 'tf_efficientnet_b6': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b6_aa-80ba17e4.pth', + 'tf_efficientnet_b7': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b7_ra-6c08e654.pth', + 'tf_efficientnet_b8': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b8_ra-572d5dd9.pth', + + 'tf_efficientnet_b0_ap': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0_ap-f262efe1.pth', + 'tf_efficientnet_b1_ap': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b1_ap-44ef0a3d.pth', + 'tf_efficientnet_b2_ap': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b2_ap-2f8e7636.pth', + 'tf_efficientnet_b3_ap': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b3_ap-aad25bdd.pth', + 'tf_efficientnet_b4_ap': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b4_ap-dedb23e6.pth', + 'tf_efficientnet_b5_ap': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b5_ap-9e82fae8.pth', + 'tf_efficientnet_b6_ap': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b6_ap-4ffb161f.pth', + 'tf_efficientnet_b7_ap': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b7_ap-ddb28fec.pth', + 'tf_efficientnet_b8_ap': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b8_ap-00e169fa.pth', + + 'tf_efficientnet_b0_ns': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0_ns-c0e6a31c.pth', + 'tf_efficientnet_b1_ns': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b1_ns-99dd0c41.pth', + 'tf_efficientnet_b2_ns': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b2_ns-00306e48.pth', + 'tf_efficientnet_b3_ns': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b3_ns-9d44bf68.pth', + 'tf_efficientnet_b4_ns': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b4_ns-d6313a46.pth', + 'tf_efficientnet_b5_ns': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b5_ns-6f26d0cf.pth', + 'tf_efficientnet_b6_ns': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b6_ns-51548356.pth', + 'tf_efficientnet_b7_ns': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b7_ns-1dbc32de.pth', + 'tf_efficientnet_l2_ns_475': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_l2_ns_475-bebbd00a.pth', + 'tf_efficientnet_l2_ns': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_l2_ns-df73bb44.pth', + + 'tf_efficientnet_es': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_es-ca1afbfe.pth', + 'tf_efficientnet_em': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_em-e78cfe58.pth', + 'tf_efficientnet_el': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_el-5143854e.pth', + + 'tf_efficientnet_cc_b0_4e': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_cc_b0_4e-4362b6b2.pth', + 'tf_efficientnet_cc_b0_8e': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_cc_b0_8e-66184a25.pth', + 'tf_efficientnet_cc_b1_8e': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_cc_b1_8e-f7c79ae1.pth', + + 'tf_efficientnet_lite0': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_lite0-0aa007d2.pth', + 'tf_efficientnet_lite1': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_lite1-bde8b488.pth', + 'tf_efficientnet_lite2': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_lite2-dcccb7df.pth', + 'tf_efficientnet_lite3': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_lite3-b733e338.pth', + 'tf_efficientnet_lite4': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_lite4-741542c3.pth', + + 'mixnet_s': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mixnet_s-a907afbc.pth', + 'mixnet_m': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mixnet_m-4647fc68.pth', + 'mixnet_l': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mixnet_l-5a9a2ed8.pth', + 'mixnet_xl': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mixnet_xl_ra-aac3c00c.pth', + + 'tf_mixnet_s': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mixnet_s-89d3354b.pth', + 'tf_mixnet_m': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mixnet_m-0f4d8805.pth', + 'tf_mixnet_l': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mixnet_l-6c92e0c8.pth', +} + + +class GenEfficientNet(nn.Module): + """ Generic EfficientNets + + An implementation of mobile optimized networks that covers: + * EfficientNet (B0-B8, L2, CondConv, EdgeTPU) + * MixNet (Small, Medium, and Large, XL) + * MNASNet A1, B1, and small + * FBNet C + * Single-Path NAS Pixel1 + """ + + def __init__(self, block_args, num_classes=1000, in_chans=3, num_features=1280, stem_size=32, fix_stem=False, + channel_multiplier=1.0, channel_divisor=8, channel_min=None, + pad_type='', act_layer=nn.ReLU, drop_rate=0., drop_connect_rate=0., + se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None, + weight_init='goog'): + super(GenEfficientNet, self).__init__() + self.drop_rate = drop_rate + + if not fix_stem: + stem_size = round_channels(stem_size, channel_multiplier, channel_divisor, channel_min) + self.conv_stem = select_conv2d(in_chans, stem_size, 3, stride=2, padding=pad_type) + self.bn1 = norm_layer(stem_size, **norm_kwargs) + self.act1 = act_layer(inplace=True) + in_chs = stem_size + + builder = EfficientNetBuilder( + channel_multiplier, channel_divisor, channel_min, + pad_type, act_layer, se_kwargs, norm_layer, norm_kwargs, drop_connect_rate) + self.blocks = nn.Sequential(*builder(in_chs, block_args)) + in_chs = builder.in_chs + + self.conv_head = select_conv2d(in_chs, num_features, 1, padding=pad_type) + self.bn2 = norm_layer(num_features, **norm_kwargs) + self.act2 = act_layer(inplace=True) + self.global_pool = nn.AdaptiveAvgPool2d(1) + self.classifier = nn.Linear(num_features, num_classes) + + for n, m in self.named_modules(): + if weight_init == 'goog': + initialize_weight_goog(m, n) + else: + initialize_weight_default(m, n) + + def features(self, x): + x = self.conv_stem(x) + x = self.bn1(x) + x = self.act1(x) + x = self.blocks(x) + x = self.conv_head(x) + x = self.bn2(x) + x = self.act2(x) + return x + + def as_sequential(self): + layers = [self.conv_stem, self.bn1, self.act1] + layers.extend(self.blocks) + layers.extend([ + self.conv_head, self.bn2, self.act2, + self.global_pool, nn.Flatten(), nn.Dropout(self.drop_rate), self.classifier]) + return nn.Sequential(*layers) + + def forward(self, x): + x = self.features(x) + x = self.global_pool(x) + x = x.flatten(1) + if self.drop_rate > 0.: + x = F.dropout(x, p=self.drop_rate, training=self.training) + return self.classifier(x) + + +def _create_model(model_kwargs, variant, pretrained=False): + as_sequential = model_kwargs.pop('as_sequential', False) + model = GenEfficientNet(**model_kwargs) + if pretrained: + load_pretrained(model, model_urls[variant]) + if as_sequential: + model = model.as_sequential() + return model + + +def _gen_mnasnet_a1(variant, channel_multiplier=1.0, pretrained=False, **kwargs): + """Creates a mnasnet-a1 model. + + Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet + Paper: https://arxiv.org/pdf/1807.11626.pdf. + + Args: + channel_multiplier: multiplier to number of channels per layer. + """ + arch_def = [ + # stage 0, 112x112 in + ['ds_r1_k3_s1_e1_c16_noskip'], + # stage 1, 112x112 in + ['ir_r2_k3_s2_e6_c24'], + # stage 2, 56x56 in + ['ir_r3_k5_s2_e3_c40_se0.25'], + # stage 3, 28x28 in + ['ir_r4_k3_s2_e6_c80'], + # stage 4, 14x14in + ['ir_r2_k3_s1_e6_c112_se0.25'], + # stage 5, 14x14in + ['ir_r3_k5_s2_e6_c160_se0.25'], + # stage 6, 7x7 in + ['ir_r1_k3_s1_e6_c320'], + ] + with layer_config_kwargs(kwargs): + model_kwargs = dict( + block_args=decode_arch_def(arch_def), + stem_size=32, + channel_multiplier=channel_multiplier, + act_layer=resolve_act_layer(kwargs, 'relu'), + norm_kwargs=resolve_bn_args(kwargs), + **kwargs + ) + model = _create_model(model_kwargs, variant, pretrained) + return model + + +def _gen_mnasnet_b1(variant, channel_multiplier=1.0, pretrained=False, **kwargs): + """Creates a mnasnet-b1 model. + + Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet + Paper: https://arxiv.org/pdf/1807.11626.pdf. + + Args: + channel_multiplier: multiplier to number of channels per layer. + """ + arch_def = [ + # stage 0, 112x112 in + ['ds_r1_k3_s1_c16_noskip'], + # stage 1, 112x112 in + ['ir_r3_k3_s2_e3_c24'], + # stage 2, 56x56 in + ['ir_r3_k5_s2_e3_c40'], + # stage 3, 28x28 in + ['ir_r3_k5_s2_e6_c80'], + # stage 4, 14x14in + ['ir_r2_k3_s1_e6_c96'], + # stage 5, 14x14in + ['ir_r4_k5_s2_e6_c192'], + # stage 6, 7x7 in + ['ir_r1_k3_s1_e6_c320_noskip'] + ] + with layer_config_kwargs(kwargs): + model_kwargs = dict( + block_args=decode_arch_def(arch_def), + stem_size=32, + channel_multiplier=channel_multiplier, + act_layer=resolve_act_layer(kwargs, 'relu'), + norm_kwargs=resolve_bn_args(kwargs), + **kwargs + ) + model = _create_model(model_kwargs, variant, pretrained) + return model + + +def _gen_mnasnet_small(variant, channel_multiplier=1.0, pretrained=False, **kwargs): + """Creates a mnasnet-b1 model. + + Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet + Paper: https://arxiv.org/pdf/1807.11626.pdf. + + Args: + channel_multiplier: multiplier to number of channels per layer. + """ + arch_def = [ + ['ds_r1_k3_s1_c8'], + ['ir_r1_k3_s2_e3_c16'], + ['ir_r2_k3_s2_e6_c16'], + ['ir_r4_k5_s2_e6_c32_se0.25'], + ['ir_r3_k3_s1_e6_c32_se0.25'], + ['ir_r3_k5_s2_e6_c88_se0.25'], + ['ir_r1_k3_s1_e6_c144'] + ] + with layer_config_kwargs(kwargs): + model_kwargs = dict( + block_args=decode_arch_def(arch_def), + stem_size=8, + channel_multiplier=channel_multiplier, + act_layer=resolve_act_layer(kwargs, 'relu'), + norm_kwargs=resolve_bn_args(kwargs), + **kwargs + ) + model = _create_model(model_kwargs, variant, pretrained) + return model + + +def _gen_mobilenet_v2( + variant, channel_multiplier=1.0, depth_multiplier=1.0, fix_stem_head=False, pretrained=False, **kwargs): + """ Generate MobileNet-V2 network + Ref impl: https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet_v2.py + Paper: https://arxiv.org/abs/1801.04381 + """ + arch_def = [ + ['ds_r1_k3_s1_c16'], + ['ir_r2_k3_s2_e6_c24'], + ['ir_r3_k3_s2_e6_c32'], + ['ir_r4_k3_s2_e6_c64'], + ['ir_r3_k3_s1_e6_c96'], + ['ir_r3_k3_s2_e6_c160'], + ['ir_r1_k3_s1_e6_c320'], + ] + with layer_config_kwargs(kwargs): + model_kwargs = dict( + block_args=decode_arch_def(arch_def, depth_multiplier=depth_multiplier, fix_first_last=fix_stem_head), + num_features=1280 if fix_stem_head else round_channels(1280, channel_multiplier, 8, None), + stem_size=32, + fix_stem=fix_stem_head, + channel_multiplier=channel_multiplier, + norm_kwargs=resolve_bn_args(kwargs), + act_layer=nn.ReLU6, + **kwargs + ) + model = _create_model(model_kwargs, variant, pretrained) + return model + + +def _gen_fbnetc(variant, channel_multiplier=1.0, pretrained=False, **kwargs): + """ FBNet-C + + Paper: https://arxiv.org/abs/1812.03443 + Ref Impl: https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/modeling/backbone/fbnet_modeldef.py + + NOTE: the impl above does not relate to the 'C' variant here, that was derived from paper, + it was used to confirm some building block details + """ + arch_def = [ + ['ir_r1_k3_s1_e1_c16'], + ['ir_r1_k3_s2_e6_c24', 'ir_r2_k3_s1_e1_c24'], + ['ir_r1_k5_s2_e6_c32', 'ir_r1_k5_s1_e3_c32', 'ir_r1_k5_s1_e6_c32', 'ir_r1_k3_s1_e6_c32'], + ['ir_r1_k5_s2_e6_c64', 'ir_r1_k5_s1_e3_c64', 'ir_r2_k5_s1_e6_c64'], + ['ir_r3_k5_s1_e6_c112', 'ir_r1_k5_s1_e3_c112'], + ['ir_r4_k5_s2_e6_c184'], + ['ir_r1_k3_s1_e6_c352'], + ] + with layer_config_kwargs(kwargs): + model_kwargs = dict( + block_args=decode_arch_def(arch_def), + stem_size=16, + num_features=1984, # paper suggests this, but is not 100% clear + channel_multiplier=channel_multiplier, + act_layer=resolve_act_layer(kwargs, 'relu'), + norm_kwargs=resolve_bn_args(kwargs), + **kwargs + ) + model = _create_model(model_kwargs, variant, pretrained) + return model + + +def _gen_spnasnet(variant, channel_multiplier=1.0, pretrained=False, **kwargs): + """Creates the Single-Path NAS model from search targeted for Pixel1 phone. + + Paper: https://arxiv.org/abs/1904.02877 + + Args: + channel_multiplier: multiplier to number of channels per layer. + """ + arch_def = [ + # stage 0, 112x112 in + ['ds_r1_k3_s1_c16_noskip'], + # stage 1, 112x112 in + ['ir_r3_k3_s2_e3_c24'], + # stage 2, 56x56 in + ['ir_r1_k5_s2_e6_c40', 'ir_r3_k3_s1_e3_c40'], + # stage 3, 28x28 in + ['ir_r1_k5_s2_e6_c80', 'ir_r3_k3_s1_e3_c80'], + # stage 4, 14x14in + ['ir_r1_k5_s1_e6_c96', 'ir_r3_k5_s1_e3_c96'], + # stage 5, 14x14in + ['ir_r4_k5_s2_e6_c192'], + # stage 6, 7x7 in + ['ir_r1_k3_s1_e6_c320_noskip'] + ] + with layer_config_kwargs(kwargs): + model_kwargs = dict( + block_args=decode_arch_def(arch_def), + stem_size=32, + channel_multiplier=channel_multiplier, + act_layer=resolve_act_layer(kwargs, 'relu'), + norm_kwargs=resolve_bn_args(kwargs), + **kwargs + ) + model = _create_model(model_kwargs, variant, pretrained) + return model + + +def _gen_efficientnet(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs): + """Creates an EfficientNet model. + + Ref impl: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/efficientnet_model.py + Paper: https://arxiv.org/abs/1905.11946 + + EfficientNet params + name: (channel_multiplier, depth_multiplier, resolution, dropout_rate) + 'efficientnet-b0': (1.0, 1.0, 224, 0.2), + 'efficientnet-b1': (1.0, 1.1, 240, 0.2), + 'efficientnet-b2': (1.1, 1.2, 260, 0.3), + 'efficientnet-b3': (1.2, 1.4, 300, 0.3), + 'efficientnet-b4': (1.4, 1.8, 380, 0.4), + 'efficientnet-b5': (1.6, 2.2, 456, 0.4), + 'efficientnet-b6': (1.8, 2.6, 528, 0.5), + 'efficientnet-b7': (2.0, 3.1, 600, 0.5), + 'efficientnet-b8': (2.2, 3.6, 672, 0.5), + + Args: + channel_multiplier: multiplier to number of channels per layer + depth_multiplier: multiplier to number of repeats per stage + + """ + arch_def = [ + ['ds_r1_k3_s1_e1_c16_se0.25'], + ['ir_r2_k3_s2_e6_c24_se0.25'], + ['ir_r2_k5_s2_e6_c40_se0.25'], + ['ir_r3_k3_s2_e6_c80_se0.25'], + ['ir_r3_k5_s1_e6_c112_se0.25'], + ['ir_r4_k5_s2_e6_c192_se0.25'], + ['ir_r1_k3_s1_e6_c320_se0.25'], + ] + with layer_config_kwargs(kwargs): + model_kwargs = dict( + block_args=decode_arch_def(arch_def, depth_multiplier), + num_features=round_channels(1280, channel_multiplier, 8, None), + stem_size=32, + channel_multiplier=channel_multiplier, + act_layer=resolve_act_layer(kwargs, 'swish'), + norm_kwargs=resolve_bn_args(kwargs), + **kwargs, + ) + model = _create_model(model_kwargs, variant, pretrained) + return model + + +def _gen_efficientnet_edge(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs): + arch_def = [ + # NOTE `fc` is present to override a mismatch between stem channels and in chs not + # present in other models + ['er_r1_k3_s1_e4_c24_fc24_noskip'], + ['er_r2_k3_s2_e8_c32'], + ['er_r4_k3_s2_e8_c48'], + ['ir_r5_k5_s2_e8_c96'], + ['ir_r4_k5_s1_e8_c144'], + ['ir_r2_k5_s2_e8_c192'], + ] + with layer_config_kwargs(kwargs): + model_kwargs = dict( + block_args=decode_arch_def(arch_def, depth_multiplier), + num_features=round_channels(1280, channel_multiplier, 8, None), + stem_size=32, + channel_multiplier=channel_multiplier, + act_layer=resolve_act_layer(kwargs, 'relu'), + norm_kwargs=resolve_bn_args(kwargs), + **kwargs, + ) + model = _create_model(model_kwargs, variant, pretrained) + return model + + +def _gen_efficientnet_condconv( + variant, channel_multiplier=1.0, depth_multiplier=1.0, experts_multiplier=1, pretrained=False, **kwargs): + """Creates an efficientnet-condconv model.""" + arch_def = [ + ['ds_r1_k3_s1_e1_c16_se0.25'], + ['ir_r2_k3_s2_e6_c24_se0.25'], + ['ir_r2_k5_s2_e6_c40_se0.25'], + ['ir_r3_k3_s2_e6_c80_se0.25'], + ['ir_r3_k5_s1_e6_c112_se0.25_cc4'], + ['ir_r4_k5_s2_e6_c192_se0.25_cc4'], + ['ir_r1_k3_s1_e6_c320_se0.25_cc4'], + ] + with layer_config_kwargs(kwargs): + model_kwargs = dict( + block_args=decode_arch_def(arch_def, depth_multiplier, experts_multiplier=experts_multiplier), + num_features=round_channels(1280, channel_multiplier, 8, None), + stem_size=32, + channel_multiplier=channel_multiplier, + act_layer=resolve_act_layer(kwargs, 'swish'), + norm_kwargs=resolve_bn_args(kwargs), + **kwargs, + ) + model = _create_model(model_kwargs, variant, pretrained) + return model + + +def _gen_efficientnet_lite(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs): + """Creates an EfficientNet-Lite model. + + Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet/lite + Paper: https://arxiv.org/abs/1905.11946 + + EfficientNet params + name: (channel_multiplier, depth_multiplier, resolution, dropout_rate) + 'efficientnet-lite0': (1.0, 1.0, 224, 0.2), + 'efficientnet-lite1': (1.0, 1.1, 240, 0.2), + 'efficientnet-lite2': (1.1, 1.2, 260, 0.3), + 'efficientnet-lite3': (1.2, 1.4, 280, 0.3), + 'efficientnet-lite4': (1.4, 1.8, 300, 0.3), + + Args: + channel_multiplier: multiplier to number of channels per layer + depth_multiplier: multiplier to number of repeats per stage + """ + arch_def = [ + ['ds_r1_k3_s1_e1_c16'], + ['ir_r2_k3_s2_e6_c24'], + ['ir_r2_k5_s2_e6_c40'], + ['ir_r3_k3_s2_e6_c80'], + ['ir_r3_k5_s1_e6_c112'], + ['ir_r4_k5_s2_e6_c192'], + ['ir_r1_k3_s1_e6_c320'], + ] + with layer_config_kwargs(kwargs): + model_kwargs = dict( + block_args=decode_arch_def(arch_def, depth_multiplier, fix_first_last=True), + num_features=1280, + stem_size=32, + fix_stem=True, + channel_multiplier=channel_multiplier, + act_layer=nn.ReLU6, + norm_kwargs=resolve_bn_args(kwargs), + **kwargs, + ) + model = _create_model(model_kwargs, variant, pretrained) + return model + + +def _gen_mixnet_s(variant, channel_multiplier=1.0, pretrained=False, **kwargs): + """Creates a MixNet Small model. + + Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet/mixnet + Paper: https://arxiv.org/abs/1907.09595 + """ + arch_def = [ + # stage 0, 112x112 in + ['ds_r1_k3_s1_e1_c16'], # relu + # stage 1, 112x112 in + ['ir_r1_k3_a1.1_p1.1_s2_e6_c24', 'ir_r1_k3_a1.1_p1.1_s1_e3_c24'], # relu + # stage 2, 56x56 in + ['ir_r1_k3.5.7_s2_e6_c40_se0.5_nsw', 'ir_r3_k3.5_a1.1_p1.1_s1_e6_c40_se0.5_nsw'], # swish + # stage 3, 28x28 in + ['ir_r1_k3.5.7_p1.1_s2_e6_c80_se0.25_nsw', 'ir_r2_k3.5_p1.1_s1_e6_c80_se0.25_nsw'], # swish + # stage 4, 14x14in + ['ir_r1_k3.5.7_a1.1_p1.1_s1_e6_c120_se0.5_nsw', 'ir_r2_k3.5.7.9_a1.1_p1.1_s1_e3_c120_se0.5_nsw'], # swish + # stage 5, 14x14in + ['ir_r1_k3.5.7.9.11_s2_e6_c200_se0.5_nsw', 'ir_r2_k3.5.7.9_p1.1_s1_e6_c200_se0.5_nsw'], # swish + # 7x7 + ] + with layer_config_kwargs(kwargs): + model_kwargs = dict( + block_args=decode_arch_def(arch_def), + num_features=1536, + stem_size=16, + channel_multiplier=channel_multiplier, + act_layer=resolve_act_layer(kwargs, 'relu'), + norm_kwargs=resolve_bn_args(kwargs), + **kwargs + ) + model = _create_model(model_kwargs, variant, pretrained) + return model + + +def _gen_mixnet_m(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs): + """Creates a MixNet Medium-Large model. + + Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet/mixnet + Paper: https://arxiv.org/abs/1907.09595 + """ + arch_def = [ + # stage 0, 112x112 in + ['ds_r1_k3_s1_e1_c24'], # relu + # stage 1, 112x112 in + ['ir_r1_k3.5.7_a1.1_p1.1_s2_e6_c32', 'ir_r1_k3_a1.1_p1.1_s1_e3_c32'], # relu + # stage 2, 56x56 in + ['ir_r1_k3.5.7.9_s2_e6_c40_se0.5_nsw', 'ir_r3_k3.5_a1.1_p1.1_s1_e6_c40_se0.5_nsw'], # swish + # stage 3, 28x28 in + ['ir_r1_k3.5.7_s2_e6_c80_se0.25_nsw', 'ir_r3_k3.5.7.9_a1.1_p1.1_s1_e6_c80_se0.25_nsw'], # swish + # stage 4, 14x14in + ['ir_r1_k3_s1_e6_c120_se0.5_nsw', 'ir_r3_k3.5.7.9_a1.1_p1.1_s1_e3_c120_se0.5_nsw'], # swish + # stage 5, 14x14in + ['ir_r1_k3.5.7.9_s2_e6_c200_se0.5_nsw', 'ir_r3_k3.5.7.9_p1.1_s1_e6_c200_se0.5_nsw'], # swish + # 7x7 + ] + with layer_config_kwargs(kwargs): + model_kwargs = dict( + block_args=decode_arch_def(arch_def, depth_multiplier, depth_trunc='round'), + num_features=1536, + stem_size=24, + channel_multiplier=channel_multiplier, + act_layer=resolve_act_layer(kwargs, 'relu'), + norm_kwargs=resolve_bn_args(kwargs), + **kwargs + ) + model = _create_model(model_kwargs, variant, pretrained) + return model + + +def mnasnet_050(pretrained=False, **kwargs): + """ MNASNet B1, depth multiplier of 0.5. """ + model = _gen_mnasnet_b1('mnasnet_050', 0.5, pretrained=pretrained, **kwargs) + return model + + +def mnasnet_075(pretrained=False, **kwargs): + """ MNASNet B1, depth multiplier of 0.75. """ + model = _gen_mnasnet_b1('mnasnet_075', 0.75, pretrained=pretrained, **kwargs) + return model + + +def mnasnet_100(pretrained=False, **kwargs): + """ MNASNet B1, depth multiplier of 1.0. """ + model = _gen_mnasnet_b1('mnasnet_100', 1.0, pretrained=pretrained, **kwargs) + return model + + +def mnasnet_b1(pretrained=False, **kwargs): + """ MNASNet B1, depth multiplier of 1.0. """ + return mnasnet_100(pretrained, **kwargs) + + +def mnasnet_140(pretrained=False, **kwargs): + """ MNASNet B1, depth multiplier of 1.4 """ + model = _gen_mnasnet_b1('mnasnet_140', 1.4, pretrained=pretrained, **kwargs) + return model + + +def semnasnet_050(pretrained=False, **kwargs): + """ MNASNet A1 (w/ SE), depth multiplier of 0.5 """ + model = _gen_mnasnet_a1('semnasnet_050', 0.5, pretrained=pretrained, **kwargs) + return model + + +def semnasnet_075(pretrained=False, **kwargs): + """ MNASNet A1 (w/ SE), depth multiplier of 0.75. """ + model = _gen_mnasnet_a1('semnasnet_075', 0.75, pretrained=pretrained, **kwargs) + return model + + +def semnasnet_100(pretrained=False, **kwargs): + """ MNASNet A1 (w/ SE), depth multiplier of 1.0. """ + model = _gen_mnasnet_a1('semnasnet_100', 1.0, pretrained=pretrained, **kwargs) + return model + + +def mnasnet_a1(pretrained=False, **kwargs): + """ MNASNet A1 (w/ SE), depth multiplier of 1.0. """ + return semnasnet_100(pretrained, **kwargs) + + +def semnasnet_140(pretrained=False, **kwargs): + """ MNASNet A1 (w/ SE), depth multiplier of 1.4. """ + model = _gen_mnasnet_a1('semnasnet_140', 1.4, pretrained=pretrained, **kwargs) + return model + + +def mnasnet_small(pretrained=False, **kwargs): + """ MNASNet Small, depth multiplier of 1.0. """ + model = _gen_mnasnet_small('mnasnet_small', 1.0, pretrained=pretrained, **kwargs) + return model + + +def mobilenetv2_100(pretrained=False, **kwargs): + """ MobileNet V2 w/ 1.0 channel multiplier """ + model = _gen_mobilenet_v2('mobilenetv2_100', 1.0, pretrained=pretrained, **kwargs) + return model + + +def mobilenetv2_140(pretrained=False, **kwargs): + """ MobileNet V2 w/ 1.4 channel multiplier """ + model = _gen_mobilenet_v2('mobilenetv2_140', 1.4, pretrained=pretrained, **kwargs) + return model + + +def mobilenetv2_110d(pretrained=False, **kwargs): + """ MobileNet V2 w/ 1.1 channel, 1.2 depth multipliers""" + model = _gen_mobilenet_v2( + 'mobilenetv2_110d', 1.1, depth_multiplier=1.2, fix_stem_head=True, pretrained=pretrained, **kwargs) + return model + + +def mobilenetv2_120d(pretrained=False, **kwargs): + """ MobileNet V2 w/ 1.2 channel, 1.4 depth multipliers """ + model = _gen_mobilenet_v2( + 'mobilenetv2_120d', 1.2, depth_multiplier=1.4, fix_stem_head=True, pretrained=pretrained, **kwargs) + return model + + +def fbnetc_100(pretrained=False, **kwargs): + """ FBNet-C """ + if pretrained: + # pretrained model trained with non-default BN epsilon + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + model = _gen_fbnetc('fbnetc_100', 1.0, pretrained=pretrained, **kwargs) + return model + + +def spnasnet_100(pretrained=False, **kwargs): + """ Single-Path NAS Pixel1""" + model = _gen_spnasnet('spnasnet_100', 1.0, pretrained=pretrained, **kwargs) + return model + + +def efficientnet_b0(pretrained=False, **kwargs): + """ EfficientNet-B0 """ + # NOTE for train set drop_rate=0.2, drop_connect_rate=0.2 + model = _gen_efficientnet( + 'efficientnet_b0', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs) + return model + + +def efficientnet_b1(pretrained=False, **kwargs): + """ EfficientNet-B1 """ + # NOTE for train set drop_rate=0.2, drop_connect_rate=0.2 + model = _gen_efficientnet( + 'efficientnet_b1', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs) + return model + + +def efficientnet_b2(pretrained=False, **kwargs): + """ EfficientNet-B2 """ + # NOTE for train set drop_rate=0.3, drop_connect_rate=0.2 + model = _gen_efficientnet( + 'efficientnet_b2', channel_multiplier=1.1, depth_multiplier=1.2, pretrained=pretrained, **kwargs) + return model + + +def efficientnet_b3(pretrained=False, **kwargs): + """ EfficientNet-B3 """ + # NOTE for train set drop_rate=0.3, drop_connect_rate=0.2 + model = _gen_efficientnet( + 'efficientnet_b3', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs) + return model + + +def efficientnet_b4(pretrained=False, **kwargs): + """ EfficientNet-B4 """ + # NOTE for train set drop_rate=0.4, drop_connect_rate=0.2 + model = _gen_efficientnet( + 'efficientnet_b4', channel_multiplier=1.4, depth_multiplier=1.8, pretrained=pretrained, **kwargs) + return model + + +def efficientnet_b5(pretrained=False, **kwargs): + """ EfficientNet-B5 """ + # NOTE for train set drop_rate=0.4, drop_connect_rate=0.2 + model = _gen_efficientnet( + 'efficientnet_b5', channel_multiplier=1.6, depth_multiplier=2.2, pretrained=pretrained, **kwargs) + return model + + +def efficientnet_b6(pretrained=False, **kwargs): + """ EfficientNet-B6 """ + # NOTE for train set drop_rate=0.5, drop_connect_rate=0.2 + model = _gen_efficientnet( + 'efficientnet_b6', channel_multiplier=1.8, depth_multiplier=2.6, pretrained=pretrained, **kwargs) + return model + + +def efficientnet_b7(pretrained=False, **kwargs): + """ EfficientNet-B7 """ + # NOTE for train set drop_rate=0.5, drop_connect_rate=0.2 + model = _gen_efficientnet( + 'efficientnet_b7', channel_multiplier=2.0, depth_multiplier=3.1, pretrained=pretrained, **kwargs) + return model + + +def efficientnet_b8(pretrained=False, **kwargs): + """ EfficientNet-B8 """ + # NOTE for train set drop_rate=0.5, drop_connect_rate=0.2 + model = _gen_efficientnet( + 'efficientnet_b8', channel_multiplier=2.2, depth_multiplier=3.6, pretrained=pretrained, **kwargs) + return model + + +def efficientnet_l2(pretrained=False, **kwargs): + """ EfficientNet-L2. """ + # NOTE for train, drop_rate should be 0.5 + model = _gen_efficientnet( + 'efficientnet_l2', channel_multiplier=4.3, depth_multiplier=5.3, pretrained=pretrained, **kwargs) + return model + + +def efficientnet_es(pretrained=False, **kwargs): + """ EfficientNet-Edge Small. """ + model = _gen_efficientnet_edge( + 'efficientnet_es', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs) + return model + + +def efficientnet_em(pretrained=False, **kwargs): + """ EfficientNet-Edge-Medium. """ + model = _gen_efficientnet_edge( + 'efficientnet_em', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs) + return model + + +def efficientnet_el(pretrained=False, **kwargs): + """ EfficientNet-Edge-Large. """ + model = _gen_efficientnet_edge( + 'efficientnet_el', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs) + return model + + +def efficientnet_cc_b0_4e(pretrained=False, **kwargs): + """ EfficientNet-CondConv-B0 w/ 8 Experts """ + # NOTE for train set drop_rate=0.25, drop_connect_rate=0.2 + model = _gen_efficientnet_condconv( + 'efficientnet_cc_b0_4e', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs) + return model + + +def efficientnet_cc_b0_8e(pretrained=False, **kwargs): + """ EfficientNet-CondConv-B0 w/ 8 Experts """ + # NOTE for train set drop_rate=0.25, drop_connect_rate=0.2 + model = _gen_efficientnet_condconv( + 'efficientnet_cc_b0_8e', channel_multiplier=1.0, depth_multiplier=1.0, experts_multiplier=2, + pretrained=pretrained, **kwargs) + return model + + +def efficientnet_cc_b1_8e(pretrained=False, **kwargs): + """ EfficientNet-CondConv-B1 w/ 8 Experts """ + # NOTE for train set drop_rate=0.25, drop_connect_rate=0.2 + model = _gen_efficientnet_condconv( + 'efficientnet_cc_b1_8e', channel_multiplier=1.0, depth_multiplier=1.1, experts_multiplier=2, + pretrained=pretrained, **kwargs) + return model + + +def efficientnet_lite0(pretrained=False, **kwargs): + """ EfficientNet-Lite0 """ + model = _gen_efficientnet_lite( + 'efficientnet_lite0', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs) + return model + + +def efficientnet_lite1(pretrained=False, **kwargs): + """ EfficientNet-Lite1 """ + model = _gen_efficientnet_lite( + 'efficientnet_lite1', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs) + return model + + +def efficientnet_lite2(pretrained=False, **kwargs): + """ EfficientNet-Lite2 """ + model = _gen_efficientnet_lite( + 'efficientnet_lite2', channel_multiplier=1.1, depth_multiplier=1.2, pretrained=pretrained, **kwargs) + return model + + +def efficientnet_lite3(pretrained=False, **kwargs): + """ EfficientNet-Lite3 """ + model = _gen_efficientnet_lite( + 'efficientnet_lite3', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs) + return model + + +def efficientnet_lite4(pretrained=False, **kwargs): + """ EfficientNet-Lite4 """ + model = _gen_efficientnet_lite( + 'efficientnet_lite4', channel_multiplier=1.4, depth_multiplier=1.8, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_b0(pretrained=False, **kwargs): + """ EfficientNet-B0 AutoAug. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b0', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_b1(pretrained=False, **kwargs): + """ EfficientNet-B1 AutoAug. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b1', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_b2(pretrained=False, **kwargs): + """ EfficientNet-B2 AutoAug. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b2', channel_multiplier=1.1, depth_multiplier=1.2, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_b3(pretrained=False, **kwargs): + """ EfficientNet-B3 AutoAug. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b3', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_b4(pretrained=False, **kwargs): + """ EfficientNet-B4 AutoAug. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b4', channel_multiplier=1.4, depth_multiplier=1.8, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_b5(pretrained=False, **kwargs): + """ EfficientNet-B5 RandAug. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b5', channel_multiplier=1.6, depth_multiplier=2.2, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_b6(pretrained=False, **kwargs): + """ EfficientNet-B6 AutoAug. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b6', channel_multiplier=1.8, depth_multiplier=2.6, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_b7(pretrained=False, **kwargs): + """ EfficientNet-B7 RandAug. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b7', channel_multiplier=2.0, depth_multiplier=3.1, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_b8(pretrained=False, **kwargs): + """ EfficientNet-B8 RandAug. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b8', channel_multiplier=2.2, depth_multiplier=3.6, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_b0_ap(pretrained=False, **kwargs): + """ EfficientNet-B0 AdvProp. Tensorflow compatible variant + Paper: Adversarial Examples Improve Image Recognition (https://arxiv.org/abs/1911.09665) + """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b0_ap', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_b1_ap(pretrained=False, **kwargs): + """ EfficientNet-B1 AdvProp. Tensorflow compatible variant + Paper: Adversarial Examples Improve Image Recognition (https://arxiv.org/abs/1911.09665) + """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b1_ap', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_b2_ap(pretrained=False, **kwargs): + """ EfficientNet-B2 AdvProp. Tensorflow compatible variant + Paper: Adversarial Examples Improve Image Recognition (https://arxiv.org/abs/1911.09665) + """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b2_ap', channel_multiplier=1.1, depth_multiplier=1.2, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_b3_ap(pretrained=False, **kwargs): + """ EfficientNet-B3 AdvProp. Tensorflow compatible variant + Paper: Adversarial Examples Improve Image Recognition (https://arxiv.org/abs/1911.09665) + """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b3_ap', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_b4_ap(pretrained=False, **kwargs): + """ EfficientNet-B4 AdvProp. Tensorflow compatible variant + Paper: Adversarial Examples Improve Image Recognition (https://arxiv.org/abs/1911.09665) + """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b4_ap', channel_multiplier=1.4, depth_multiplier=1.8, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_b5_ap(pretrained=False, **kwargs): + """ EfficientNet-B5 AdvProp. Tensorflow compatible variant + Paper: Adversarial Examples Improve Image Recognition (https://arxiv.org/abs/1911.09665) + """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b5_ap', channel_multiplier=1.6, depth_multiplier=2.2, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_b6_ap(pretrained=False, **kwargs): + """ EfficientNet-B6 AdvProp. Tensorflow compatible variant + Paper: Adversarial Examples Improve Image Recognition (https://arxiv.org/abs/1911.09665) + """ + # NOTE for train, drop_rate should be 0.5 + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b6_ap', channel_multiplier=1.8, depth_multiplier=2.6, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_b7_ap(pretrained=False, **kwargs): + """ EfficientNet-B7 AdvProp. Tensorflow compatible variant + Paper: Adversarial Examples Improve Image Recognition (https://arxiv.org/abs/1911.09665) + """ + # NOTE for train, drop_rate should be 0.5 + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b7_ap', channel_multiplier=2.0, depth_multiplier=3.1, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_b8_ap(pretrained=False, **kwargs): + """ EfficientNet-B8 AdvProp. Tensorflow compatible variant + Paper: Adversarial Examples Improve Image Recognition (https://arxiv.org/abs/1911.09665) + """ + # NOTE for train, drop_rate should be 0.5 + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b8_ap', channel_multiplier=2.2, depth_multiplier=3.6, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_b0_ns(pretrained=False, **kwargs): + """ EfficientNet-B0 NoisyStudent. Tensorflow compatible variant + Paper: Self-training with Noisy Student improves ImageNet classification (https://arxiv.org/abs/1911.04252) + """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b0_ns', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_b1_ns(pretrained=False, **kwargs): + """ EfficientNet-B1 NoisyStudent. Tensorflow compatible variant + Paper: Self-training with Noisy Student improves ImageNet classification (https://arxiv.org/abs/1911.04252) + """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b1_ns', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_b2_ns(pretrained=False, **kwargs): + """ EfficientNet-B2 NoisyStudent. Tensorflow compatible variant + Paper: Self-training with Noisy Student improves ImageNet classification (https://arxiv.org/abs/1911.04252) + """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b2_ns', channel_multiplier=1.1, depth_multiplier=1.2, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_b3_ns(pretrained=False, **kwargs): + """ EfficientNet-B3 NoisyStudent. Tensorflow compatible variant + Paper: Self-training with Noisy Student improves ImageNet classification (https://arxiv.org/abs/1911.04252) + """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b3_ns', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_b4_ns(pretrained=False, **kwargs): + """ EfficientNet-B4 NoisyStudent. Tensorflow compatible variant + Paper: Self-training with Noisy Student improves ImageNet classification (https://arxiv.org/abs/1911.04252) + """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b4_ns', channel_multiplier=1.4, depth_multiplier=1.8, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_b5_ns(pretrained=False, **kwargs): + """ EfficientNet-B5 NoisyStudent. Tensorflow compatible variant + Paper: Self-training with Noisy Student improves ImageNet classification (https://arxiv.org/abs/1911.04252) + """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b5_ns', channel_multiplier=1.6, depth_multiplier=2.2, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_b6_ns(pretrained=False, **kwargs): + """ EfficientNet-B6 NoisyStudent. Tensorflow compatible variant + Paper: Self-training with Noisy Student improves ImageNet classification (https://arxiv.org/abs/1911.04252) + """ + # NOTE for train, drop_rate should be 0.5 + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b6_ns', channel_multiplier=1.8, depth_multiplier=2.6, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_b7_ns(pretrained=False, **kwargs): + """ EfficientNet-B7 NoisyStudent. Tensorflow compatible variant + Paper: Self-training with Noisy Student improves ImageNet classification (https://arxiv.org/abs/1911.04252) + """ + # NOTE for train, drop_rate should be 0.5 + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b7_ns', channel_multiplier=2.0, depth_multiplier=3.1, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_l2_ns_475(pretrained=False, **kwargs): + """ EfficientNet-L2 NoisyStudent @ 475x475. Tensorflow compatible variant + Paper: Self-training with Noisy Student improves ImageNet classification (https://arxiv.org/abs/1911.04252) + """ + # NOTE for train, drop_rate should be 0.5 + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_l2_ns_475', channel_multiplier=4.3, depth_multiplier=5.3, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_l2_ns(pretrained=False, **kwargs): + """ EfficientNet-L2 NoisyStudent. Tensorflow compatible variant + Paper: Self-training with Noisy Student improves ImageNet classification (https://arxiv.org/abs/1911.04252) + """ + # NOTE for train, drop_rate should be 0.5 + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_l2_ns', channel_multiplier=4.3, depth_multiplier=5.3, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_es(pretrained=False, **kwargs): + """ EfficientNet-Edge Small. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet_edge( + 'tf_efficientnet_es', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_em(pretrained=False, **kwargs): + """ EfficientNet-Edge-Medium. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet_edge( + 'tf_efficientnet_em', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_el(pretrained=False, **kwargs): + """ EfficientNet-Edge-Large. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet_edge( + 'tf_efficientnet_el', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_cc_b0_4e(pretrained=False, **kwargs): + """ EfficientNet-CondConv-B0 w/ 4 Experts """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet_condconv( + 'tf_efficientnet_cc_b0_4e', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_cc_b0_8e(pretrained=False, **kwargs): + """ EfficientNet-CondConv-B0 w/ 8 Experts """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet_condconv( + 'tf_efficientnet_cc_b0_8e', channel_multiplier=1.0, depth_multiplier=1.0, experts_multiplier=2, + pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_cc_b1_8e(pretrained=False, **kwargs): + """ EfficientNet-CondConv-B1 w/ 8 Experts """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet_condconv( + 'tf_efficientnet_cc_b1_8e', channel_multiplier=1.0, depth_multiplier=1.1, experts_multiplier=2, + pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_lite0(pretrained=False, **kwargs): + """ EfficientNet-Lite0. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet_lite( + 'tf_efficientnet_lite0', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_lite1(pretrained=False, **kwargs): + """ EfficientNet-Lite1. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet_lite( + 'tf_efficientnet_lite1', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_lite2(pretrained=False, **kwargs): + """ EfficientNet-Lite2. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet_lite( + 'tf_efficientnet_lite2', channel_multiplier=1.1, depth_multiplier=1.2, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_lite3(pretrained=False, **kwargs): + """ EfficientNet-Lite3. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet_lite( + 'tf_efficientnet_lite3', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_lite4(pretrained=False, **kwargs): + """ EfficientNet-Lite4. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet_lite( + 'tf_efficientnet_lite4', channel_multiplier=1.4, depth_multiplier=1.8, pretrained=pretrained, **kwargs) + return model + + +def mixnet_s(pretrained=False, **kwargs): + """Creates a MixNet Small model. + """ + # NOTE for train set drop_rate=0.2 + model = _gen_mixnet_s( + 'mixnet_s', channel_multiplier=1.0, pretrained=pretrained, **kwargs) + return model + + +def mixnet_m(pretrained=False, **kwargs): + """Creates a MixNet Medium model. + """ + # NOTE for train set drop_rate=0.25 + model = _gen_mixnet_m( + 'mixnet_m', channel_multiplier=1.0, pretrained=pretrained, **kwargs) + return model + + +def mixnet_l(pretrained=False, **kwargs): + """Creates a MixNet Large model. + """ + # NOTE for train set drop_rate=0.25 + model = _gen_mixnet_m( + 'mixnet_l', channel_multiplier=1.3, pretrained=pretrained, **kwargs) + return model + + +def mixnet_xl(pretrained=False, **kwargs): + """Creates a MixNet Extra-Large model. + Not a paper spec, experimental def by RW w/ depth scaling. + """ + # NOTE for train set drop_rate=0.25, drop_connect_rate=0.2 + model = _gen_mixnet_m( + 'mixnet_xl', channel_multiplier=1.6, depth_multiplier=1.2, pretrained=pretrained, **kwargs) + return model + + +def mixnet_xxl(pretrained=False, **kwargs): + """Creates a MixNet Double Extra Large model. + Not a paper spec, experimental def by RW w/ depth scaling. + """ + # NOTE for train set drop_rate=0.3, drop_connect_rate=0.2 + model = _gen_mixnet_m( + 'mixnet_xxl', channel_multiplier=2.4, depth_multiplier=1.3, pretrained=pretrained, **kwargs) + return model + + +def tf_mixnet_s(pretrained=False, **kwargs): + """Creates a MixNet Small model. Tensorflow compatible variant + """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_mixnet_s( + 'tf_mixnet_s', channel_multiplier=1.0, pretrained=pretrained, **kwargs) + return model + + +def tf_mixnet_m(pretrained=False, **kwargs): + """Creates a MixNet Medium model. Tensorflow compatible variant + """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_mixnet_m( + 'tf_mixnet_m', channel_multiplier=1.0, pretrained=pretrained, **kwargs) + return model + + +def tf_mixnet_l(pretrained=False, **kwargs): + """Creates a MixNet Large model. Tensorflow compatible variant + """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_mixnet_m( + 'tf_mixnet_l', channel_multiplier=1.3, pretrained=pretrained, **kwargs) + return model diff --git a/geffnet/helpers.py b/geffnet/helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..3f83a07d690c7ad681c777c19b1e7a5bb95da007 --- /dev/null +++ b/geffnet/helpers.py @@ -0,0 +1,71 @@ +""" Checkpoint loading / state_dict helpers +Copyright 2020 Ross Wightman +""" +import torch +import os +from collections import OrderedDict +try: + from torch.hub import load_state_dict_from_url +except ImportError: + from torch.utils.model_zoo import load_url as load_state_dict_from_url + + +def load_checkpoint(model, checkpoint_path): + if checkpoint_path and os.path.isfile(checkpoint_path): + print("=> Loading checkpoint '{}'".format(checkpoint_path)) + checkpoint = torch.load(checkpoint_path) + if isinstance(checkpoint, dict) and 'state_dict' in checkpoint: + new_state_dict = OrderedDict() + for k, v in checkpoint['state_dict'].items(): + if k.startswith('module'): + name = k[7:] # remove `module.` + else: + name = k + new_state_dict[name] = v + model.load_state_dict(new_state_dict) + else: + model.load_state_dict(checkpoint) + print("=> Loaded checkpoint '{}'".format(checkpoint_path)) + else: + print("=> Error: No checkpoint found at '{}'".format(checkpoint_path)) + raise FileNotFoundError() + + +def load_pretrained(model, url, filter_fn=None, strict=True): + if not url: + print("=> Warning: Pretrained model URL is empty, using random initialization.") + return + + state_dict = load_state_dict_from_url(url, progress=False, map_location='cpu') + + input_conv = 'conv_stem' + classifier = 'classifier' + in_chans = getattr(model, input_conv).weight.shape[1] + num_classes = getattr(model, classifier).weight.shape[0] + + input_conv_weight = input_conv + '.weight' + pretrained_in_chans = state_dict[input_conv_weight].shape[1] + if in_chans != pretrained_in_chans: + if in_chans == 1: + print('=> Converting pretrained input conv {} from {} to 1 channel'.format( + input_conv_weight, pretrained_in_chans)) + conv1_weight = state_dict[input_conv_weight] + state_dict[input_conv_weight] = conv1_weight.sum(dim=1, keepdim=True) + else: + print('=> Discarding pretrained input conv {} since input channel count != {}'.format( + input_conv_weight, pretrained_in_chans)) + del state_dict[input_conv_weight] + strict = False + + classifier_weight = classifier + '.weight' + pretrained_num_classes = state_dict[classifier_weight].shape[0] + if num_classes != pretrained_num_classes: + print('=> Discarding pretrained classifier since num_classes != {}'.format(pretrained_num_classes)) + del state_dict[classifier_weight] + del state_dict[classifier + '.bias'] + strict = False + + if filter_fn is not None: + state_dict = filter_fn(state_dict) + + model.load_state_dict(state_dict, strict=strict) diff --git a/geffnet/mobilenetv3.py b/geffnet/mobilenetv3.py new file mode 100644 index 0000000000000000000000000000000000000000..b5966c28f7207e98ee50745b1bc8f3663c650f9d --- /dev/null +++ b/geffnet/mobilenetv3.py @@ -0,0 +1,364 @@ +""" MobileNet-V3 + +A PyTorch impl of MobileNet-V3, compatible with TF weights from official impl. + +Paper: Searching for MobileNetV3 - https://arxiv.org/abs/1905.02244 + +Hacked together by / Copyright 2020 Ross Wightman +""" +import torch.nn as nn +import torch.nn.functional as F + +from .activations import get_act_fn, get_act_layer, HardSwish +from .config import layer_config_kwargs +from .conv2d_layers import select_conv2d +from .helpers import load_pretrained +from .efficientnet_builder import * + +__all__ = ['mobilenetv3_rw', 'mobilenetv3_large_075', 'mobilenetv3_large_100', 'mobilenetv3_large_minimal_100', + 'mobilenetv3_small_075', 'mobilenetv3_small_100', 'mobilenetv3_small_minimal_100', + 'tf_mobilenetv3_large_075', 'tf_mobilenetv3_large_100', 'tf_mobilenetv3_large_minimal_100', + 'tf_mobilenetv3_small_075', 'tf_mobilenetv3_small_100', 'tf_mobilenetv3_small_minimal_100'] + +model_urls = { + 'mobilenetv3_rw': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_100-35495452.pth', + 'mobilenetv3_large_075': None, + 'mobilenetv3_large_100': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_large_100_ra-f55367f5.pth', + 'mobilenetv3_large_minimal_100': None, + 'mobilenetv3_small_075': None, + 'mobilenetv3_small_100': None, + 'mobilenetv3_small_minimal_100': None, + 'tf_mobilenetv3_large_075': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_075-150ee8b0.pth', + 'tf_mobilenetv3_large_100': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_100-427764d5.pth', + 'tf_mobilenetv3_large_minimal_100': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_minimal_100-8596ae28.pth', + 'tf_mobilenetv3_small_075': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_075-da427f52.pth', + 'tf_mobilenetv3_small_100': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_100-37f49e2b.pth', + 'tf_mobilenetv3_small_minimal_100': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_minimal_100-922a7843.pth', +} + + +class MobileNetV3(nn.Module): + """ MobileNet-V3 + + A this model utilizes the MobileNet-v3 specific 'efficient head', where global pooling is done before the + head convolution without a final batch-norm layer before the classifier. + + Paper: https://arxiv.org/abs/1905.02244 + """ + + def __init__(self, block_args, num_classes=1000, in_chans=3, stem_size=16, num_features=1280, head_bias=True, + channel_multiplier=1.0, pad_type='', act_layer=HardSwish, drop_rate=0., drop_connect_rate=0., + se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None, weight_init='goog'): + super(MobileNetV3, self).__init__() + self.drop_rate = drop_rate + + stem_size = round_channels(stem_size, channel_multiplier) + self.conv_stem = select_conv2d(in_chans, stem_size, 3, stride=2, padding=pad_type) + self.bn1 = nn.BatchNorm2d(stem_size, **norm_kwargs) + self.act1 = act_layer(inplace=True) + in_chs = stem_size + + builder = EfficientNetBuilder( + channel_multiplier, pad_type=pad_type, act_layer=act_layer, se_kwargs=se_kwargs, + norm_layer=norm_layer, norm_kwargs=norm_kwargs, drop_connect_rate=drop_connect_rate) + self.blocks = nn.Sequential(*builder(in_chs, block_args)) + in_chs = builder.in_chs + + self.global_pool = nn.AdaptiveAvgPool2d(1) + self.conv_head = select_conv2d(in_chs, num_features, 1, padding=pad_type, bias=head_bias) + self.act2 = act_layer(inplace=True) + self.classifier = nn.Linear(num_features, num_classes) + + for m in self.modules(): + if weight_init == 'goog': + initialize_weight_goog(m) + else: + initialize_weight_default(m) + + def as_sequential(self): + layers = [self.conv_stem, self.bn1, self.act1] + layers.extend(self.blocks) + layers.extend([ + self.global_pool, self.conv_head, self.act2, + nn.Flatten(), nn.Dropout(self.drop_rate), self.classifier]) + return nn.Sequential(*layers) + + def features(self, x): + x = self.conv_stem(x) + x = self.bn1(x) + x = self.act1(x) + x = self.blocks(x) + x = self.global_pool(x) + x = self.conv_head(x) + x = self.act2(x) + return x + + def forward(self, x): + x = self.features(x) + x = x.flatten(1) + if self.drop_rate > 0.: + x = F.dropout(x, p=self.drop_rate, training=self.training) + return self.classifier(x) + + +def _create_model(model_kwargs, variant, pretrained=False): + as_sequential = model_kwargs.pop('as_sequential', False) + model = MobileNetV3(**model_kwargs) + if pretrained and model_urls[variant]: + load_pretrained(model, model_urls[variant]) + if as_sequential: + model = model.as_sequential() + return model + + +def _gen_mobilenet_v3_rw(variant, channel_multiplier=1.0, pretrained=False, **kwargs): + """Creates a MobileNet-V3 model (RW variant). + + Paper: https://arxiv.org/abs/1905.02244 + + This was my first attempt at reproducing the MobileNet-V3 from paper alone. It came close to the + eventual Tensorflow reference impl but has a few differences: + 1. This model has no bias on the head convolution + 2. This model forces no residual (noskip) on the first DWS block, this is different than MnasNet + 3. This model always uses ReLU for the SE activation layer, other models in the family inherit their act layer + from their parent block + 4. This model does not enforce divisible by 8 limitation on the SE reduction channel count + + Overall the changes are fairly minor and result in a very small parameter count difference and no + top-1/5 + + Args: + channel_multiplier: multiplier to number of channels per layer. + """ + arch_def = [ + # stage 0, 112x112 in + ['ds_r1_k3_s1_e1_c16_nre_noskip'], # relu + # stage 1, 112x112 in + ['ir_r1_k3_s2_e4_c24_nre', 'ir_r1_k3_s1_e3_c24_nre'], # relu + # stage 2, 56x56 in + ['ir_r3_k5_s2_e3_c40_se0.25_nre'], # relu + # stage 3, 28x28 in + ['ir_r1_k3_s2_e6_c80', 'ir_r1_k3_s1_e2.5_c80', 'ir_r2_k3_s1_e2.3_c80'], # hard-swish + # stage 4, 14x14in + ['ir_r2_k3_s1_e6_c112_se0.25'], # hard-swish + # stage 5, 14x14in + ['ir_r3_k5_s2_e6_c160_se0.25'], # hard-swish + # stage 6, 7x7 in + ['cn_r1_k1_s1_c960'], # hard-swish + ] + with layer_config_kwargs(kwargs): + model_kwargs = dict( + block_args=decode_arch_def(arch_def), + head_bias=False, # one of my mistakes + channel_multiplier=channel_multiplier, + act_layer=resolve_act_layer(kwargs, 'hard_swish'), + se_kwargs=dict(gate_fn=get_act_fn('hard_sigmoid'), reduce_mid=True), + norm_kwargs=resolve_bn_args(kwargs), + **kwargs, + ) + model = _create_model(model_kwargs, variant, pretrained) + return model + + +def _gen_mobilenet_v3(variant, channel_multiplier=1.0, pretrained=False, **kwargs): + """Creates a MobileNet-V3 large/small/minimal models. + + Ref impl: https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet_v3.py + Paper: https://arxiv.org/abs/1905.02244 + + Args: + channel_multiplier: multiplier to number of channels per layer. + """ + if 'small' in variant: + num_features = 1024 + if 'minimal' in variant: + act_layer = 'relu' + arch_def = [ + # stage 0, 112x112 in + ['ds_r1_k3_s2_e1_c16'], + # stage 1, 56x56 in + ['ir_r1_k3_s2_e4.5_c24', 'ir_r1_k3_s1_e3.67_c24'], + # stage 2, 28x28 in + ['ir_r1_k3_s2_e4_c40', 'ir_r2_k3_s1_e6_c40'], + # stage 3, 14x14 in + ['ir_r2_k3_s1_e3_c48'], + # stage 4, 14x14in + ['ir_r3_k3_s2_e6_c96'], + # stage 6, 7x7 in + ['cn_r1_k1_s1_c576'], + ] + else: + act_layer = 'hard_swish' + arch_def = [ + # stage 0, 112x112 in + ['ds_r1_k3_s2_e1_c16_se0.25_nre'], # relu + # stage 1, 56x56 in + ['ir_r1_k3_s2_e4.5_c24_nre', 'ir_r1_k3_s1_e3.67_c24_nre'], # relu + # stage 2, 28x28 in + ['ir_r1_k5_s2_e4_c40_se0.25', 'ir_r2_k5_s1_e6_c40_se0.25'], # hard-swish + # stage 3, 14x14 in + ['ir_r2_k5_s1_e3_c48_se0.25'], # hard-swish + # stage 4, 14x14in + ['ir_r3_k5_s2_e6_c96_se0.25'], # hard-swish + # stage 6, 7x7 in + ['cn_r1_k1_s1_c576'], # hard-swish + ] + else: + num_features = 1280 + if 'minimal' in variant: + act_layer = 'relu' + arch_def = [ + # stage 0, 112x112 in + ['ds_r1_k3_s1_e1_c16'], + # stage 1, 112x112 in + ['ir_r1_k3_s2_e4_c24', 'ir_r1_k3_s1_e3_c24'], + # stage 2, 56x56 in + ['ir_r3_k3_s2_e3_c40'], + # stage 3, 28x28 in + ['ir_r1_k3_s2_e6_c80', 'ir_r1_k3_s1_e2.5_c80', 'ir_r2_k3_s1_e2.3_c80'], + # stage 4, 14x14in + ['ir_r2_k3_s1_e6_c112'], + # stage 5, 14x14in + ['ir_r3_k3_s2_e6_c160'], + # stage 6, 7x7 in + ['cn_r1_k1_s1_c960'], + ] + else: + act_layer = 'hard_swish' + arch_def = [ + # stage 0, 112x112 in + ['ds_r1_k3_s1_e1_c16_nre'], # relu + # stage 1, 112x112 in + ['ir_r1_k3_s2_e4_c24_nre', 'ir_r1_k3_s1_e3_c24_nre'], # relu + # stage 2, 56x56 in + ['ir_r3_k5_s2_e3_c40_se0.25_nre'], # relu + # stage 3, 28x28 in + ['ir_r1_k3_s2_e6_c80', 'ir_r1_k3_s1_e2.5_c80', 'ir_r2_k3_s1_e2.3_c80'], # hard-swish + # stage 4, 14x14in + ['ir_r2_k3_s1_e6_c112_se0.25'], # hard-swish + # stage 5, 14x14in + ['ir_r3_k5_s2_e6_c160_se0.25'], # hard-swish + # stage 6, 7x7 in + ['cn_r1_k1_s1_c960'], # hard-swish + ] + with layer_config_kwargs(kwargs): + model_kwargs = dict( + block_args=decode_arch_def(arch_def), + num_features=num_features, + stem_size=16, + channel_multiplier=channel_multiplier, + act_layer=resolve_act_layer(kwargs, act_layer), + se_kwargs=dict( + act_layer=get_act_layer('relu'), gate_fn=get_act_fn('hard_sigmoid'), reduce_mid=True, divisor=8), + norm_kwargs=resolve_bn_args(kwargs), + **kwargs, + ) + model = _create_model(model_kwargs, variant, pretrained) + return model + + +def mobilenetv3_rw(pretrained=False, **kwargs): + """ MobileNet-V3 RW + Attn: See note in gen function for this variant. + """ + # NOTE for train set drop_rate=0.2 + if pretrained: + # pretrained model trained with non-default BN epsilon + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + model = _gen_mobilenet_v3_rw('mobilenetv3_rw', 1.0, pretrained=pretrained, **kwargs) + return model + + +def mobilenetv3_large_075(pretrained=False, **kwargs): + """ MobileNet V3 Large 0.75""" + # NOTE for train set drop_rate=0.2 + model = _gen_mobilenet_v3('mobilenetv3_large_075', 0.75, pretrained=pretrained, **kwargs) + return model + + +def mobilenetv3_large_100(pretrained=False, **kwargs): + """ MobileNet V3 Large 1.0 """ + # NOTE for train set drop_rate=0.2 + model = _gen_mobilenet_v3('mobilenetv3_large_100', 1.0, pretrained=pretrained, **kwargs) + return model + + +def mobilenetv3_large_minimal_100(pretrained=False, **kwargs): + """ MobileNet V3 Large (Minimalistic) 1.0 """ + # NOTE for train set drop_rate=0.2 + model = _gen_mobilenet_v3('mobilenetv3_large_minimal_100', 1.0, pretrained=pretrained, **kwargs) + return model + + +def mobilenetv3_small_075(pretrained=False, **kwargs): + """ MobileNet V3 Small 0.75 """ + model = _gen_mobilenet_v3('mobilenetv3_small_075', 0.75, pretrained=pretrained, **kwargs) + return model + + +def mobilenetv3_small_100(pretrained=False, **kwargs): + """ MobileNet V3 Small 1.0 """ + model = _gen_mobilenet_v3('mobilenetv3_small_100', 1.0, pretrained=pretrained, **kwargs) + return model + + +def mobilenetv3_small_minimal_100(pretrained=False, **kwargs): + """ MobileNet V3 Small (Minimalistic) 1.0 """ + model = _gen_mobilenet_v3('mobilenetv3_small_minimal_100', 1.0, pretrained=pretrained, **kwargs) + return model + + +def tf_mobilenetv3_large_075(pretrained=False, **kwargs): + """ MobileNet V3 Large 0.75. Tensorflow compat variant. """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_mobilenet_v3('tf_mobilenetv3_large_075', 0.75, pretrained=pretrained, **kwargs) + return model + + +def tf_mobilenetv3_large_100(pretrained=False, **kwargs): + """ MobileNet V3 Large 1.0. Tensorflow compat variant. """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_mobilenet_v3('tf_mobilenetv3_large_100', 1.0, pretrained=pretrained, **kwargs) + return model + + +def tf_mobilenetv3_large_minimal_100(pretrained=False, **kwargs): + """ MobileNet V3 Large Minimalistic 1.0. Tensorflow compat variant. """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_mobilenet_v3('tf_mobilenetv3_large_minimal_100', 1.0, pretrained=pretrained, **kwargs) + return model + + +def tf_mobilenetv3_small_075(pretrained=False, **kwargs): + """ MobileNet V3 Small 0.75. Tensorflow compat variant. """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_mobilenet_v3('tf_mobilenetv3_small_075', 0.75, pretrained=pretrained, **kwargs) + return model + + +def tf_mobilenetv3_small_100(pretrained=False, **kwargs): + """ MobileNet V3 Small 1.0. Tensorflow compat variant.""" + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_mobilenet_v3('tf_mobilenetv3_small_100', 1.0, pretrained=pretrained, **kwargs) + return model + + +def tf_mobilenetv3_small_minimal_100(pretrained=False, **kwargs): + """ MobileNet V3 Small Minimalistic 1.0. Tensorflow compat variant. """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_mobilenet_v3('tf_mobilenetv3_small_minimal_100', 1.0, pretrained=pretrained, **kwargs) + return model diff --git a/geffnet/model_factory.py b/geffnet/model_factory.py new file mode 100644 index 0000000000000000000000000000000000000000..4d46ea8baedaf3d787826eb3bb314b4230514647 --- /dev/null +++ b/geffnet/model_factory.py @@ -0,0 +1,27 @@ +from .config import set_layer_config +from .helpers import load_checkpoint + +from .gen_efficientnet import * +from .mobilenetv3 import * + + +def create_model( + model_name='mnasnet_100', + pretrained=None, + num_classes=1000, + in_chans=3, + checkpoint_path='', + **kwargs): + + model_kwargs = dict(num_classes=num_classes, in_chans=in_chans, pretrained=pretrained, **kwargs) + + if model_name in globals(): + create_fn = globals()[model_name] + model = create_fn(**model_kwargs) + else: + raise RuntimeError('Unknown model (%s)' % model_name) + + if checkpoint_path and not pretrained: + load_checkpoint(model, checkpoint_path) + + return model diff --git a/geffnet/version.py b/geffnet/version.py new file mode 100644 index 0000000000000000000000000000000000000000..a6221b3de7b1490c5e712e8b5fcc94c3d9d04295 --- /dev/null +++ b/geffnet/version.py @@ -0,0 +1 @@ +__version__ = '1.0.2' diff --git a/models/README.md b/models/README.md new file mode 100644 index 0000000000000000000000000000000000000000..251d63501865439ca4e829884aecea9037d153db --- /dev/null +++ b/models/README.md @@ -0,0 +1,5 @@ +You can download the following checkpoints and put them in the current file folder: + +1. [stable diffusion](https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.ckpt) +2. [dpt model](https://github.com/intel-isl/DPT/releases/download/1_0/dpt_large-midas-2f21e586.pt) +3. [AdaBins_nyu.pt](https://drive.google.com/drive/folders/1nYyaQXOBjNdUJDsmJpcRpu6oE55aQoLA?usp=sharing) \ No newline at end of file