h1t committed on
Commit
94c2073
•
1 Parent(s): 6b5cce7

Upload folder using huggingface_hub

Files changed (3)
  1. README.md +381 -13
  2. gradio_app.py +117 -0
  3. scheduling_tcd.py +657 -0
README.md CHANGED
@@ -1,13 +1,381 @@
1
- ---
2
- title: TCD
3
- emoji: 🔥
4
- colorFrom: green
5
- colorTo: blue
6
- sdk: gradio
7
- sdk_version: 4.19.2
8
- app_file: app.py
9
- pinned: false
10
- license: mit
11
- ---
12
-
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
1
+ # Trajectory Consistency Distillation
2
+
3
+ [![Arxiv](https://img.shields.io/badge/arXiv-2211.15744-b31b1b)]()
4
+ [![Project page](https://img.shields.io/badge/Web-Project%20Page-green)](https://mhh0318.github.io/tcd)
5
+ [![Hugging Face Model](https://img.shields.io/badge/%F0%9F%A4%97HuggingFace-Model-purple)](https://huggingface.co/h1t/TCD-SDXL-LoRA)
6
+ [![Hugging Face Space](https://img.shields.io/badge/%F0%9F%A4%97HuggingFace-Space-blue)](https://huggingface.co/spaces/h1t/TCD-SDXL-LoRA)
7
+
8
+ Official Repository of the paper: [Trajectory Consistency Distillation]()
9
+
10
+ ![](./assets/teaser_fig.png)
11
+
12
+ ## 📣 News
13
+ - (🔥New) 2024/2/29 We provide a demo of TCD on 🤗 Hugging Face Spaces. Try it out [here](https://huggingface.co/spaces/h1t/TCD-SDXL-LoRA).
14
+ - (🔥New) 2024/2/29 We released our model [TCD-SDXL-LoRA](https://huggingface.co/h1t/TCD-SDXL-LoRA) on 🤗 Hugging Face.
15
+ - (🔥New) 2024/2/29 TCD is now integrated into the 🧨 Diffusers library. Please refer to the [Usage](#usage-anchor) section for more information.
16
+
17
+ ## Introduction
18
+
19
+ TCD, inspired by [Consistency Models](https://arxiv.org/abs/2303.01469), is a novel distillation technique that transfers knowledge from pre-trained diffusion models into a few-step sampler. In this repository, we release the inference code and our model TCD-SDXL, which is distilled from [SDXL Base 1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0). The LoRA checkpoint is available in this 🔥[repository](https://huggingface.co/h1t/TCD-SDXL-LoRA).
20
+
21
+ ⭐ TCD has the following advantages:
22
+
23
+ - `High Quality with Few Steps`: TCD significantly surpasses the previous state-of-the-art few-step text-to-image model [LCM](https://github.com/luosiallen/latent-consistency-model/tree/main) in image quality. Notably, LCM suffers a marked decline in quality at high NFEs. In contrast, _**TCD maintains superior generative quality at high NFEs, even exceeding the performance of DPM-Solver++(2S) with the original SDXL**_.
24
+ ![](./assets/teaser.jpeg)
25
+ <!-- We observed that the images generated with 8 steps by TCD-SDXL are already highly impressive, even outperforming the original SDXL 50-steps generation results. -->
26
+ - `Versatility`: Integrated with LoRA technology, TCD can be directly applied to various models that share the same backbone (including custom community models, styled LoRAs, ControlNet, and IP-Adapter), as demonstrated in [Usage](#usage-anchor).
27
+ ![](./assets/versatility.png)
28
+ - `Avoiding Mode Collapse`: TCD achieves few-step generation without the need for adversarial training, thus circumventing mode collapse caused by the GAN objective.
29
+ In contrast to the concurrent work [SDXL-Lightning](https://huggingface.co/ByteDance/SDXL-Lightning), which relies on Adversarial Diffusion Distillation, TCD can synthesize results that are more realistic and slightly more diverse, without the presence of "Janus" artifacts.
30
+ ![](./assets/compare_sdxl_lightning.png)
31
+
32
+ For more information, please refer to our paper [Trajectory Consistency Distillation]().
33
+
34
+ <a id="usage-anchor"></a>
35
+
36
+ ## Usage
37
+ To run the model yourself, you can leverage the 🧨 Diffusers library. First, install the required packages:
38
+ ```bash
39
+ pip install diffusers transformers accelerate peft
40
+ ```
41
+ Then clone this repository:
42
+ ```bash
43
+ git clone https://github.com/jabir-zheng/TCD.git
44
+ cd TCD
45
+ ```
46
+ Here, we demonstrate the applicability of our TCD LoRA to various models, including [SDXL](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0), [SDXL Inpainting](https://huggingface.co/diffusers/stable-diffusion-xl-1.0-inpainting-0.1), a community model named [Animagine XL](https://huggingface.co/cagliostrolab/animagine-xl-3.0), a styled LoRA [Papercut](https://huggingface.co/TheLastBen/Papercut_SDXL), a pretrained [Depth ControlNet](https://huggingface.co/diffusers/controlnet-depth-sdxl-1.0), a [Canny ControlNet](https://huggingface.co/diffusers/controlnet-canny-sdxl-1.0), and [IP-Adapter](https://github.com/tencent-ailab/IP-Adapter), to accelerate high-quality image generation in only a few steps.
47
+
48
+ ### Text-to-Image generation
49
+ ```py
50
+ import torch
51
+ from diffusers import StableDiffusionXLPipeline
52
+ from scheduling_tcd import TCDScheduler
53
+
54
+ device = "cuda"
55
+ base_model_id = "stabilityai/stable-diffusion-xl-base-1.0"
56
+ tcd_lora_id = "h1t/TCD-SDXL-LoRA"
57
+
58
+ pipe = StableDiffusionXLPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16, variant="fp16").to(device)
59
+ pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)
60
+
61
+ pipe.load_lora_weights(tcd_lora_id)
62
+ pipe.fuse_lora()
63
+
64
+ prompt = "Beautiful woman, bubblegum pink, lemon yellow, minty blue, futuristic, high-detail, epic composition, watercolor."
65
+
66
+ image = pipe(
67
+ prompt=prompt,
68
+ num_inference_steps=4,
69
+ guidance_scale=0,
70
+ # Eta (referred to as `gamma` in the paper) is used to control the stochasticity in every step.
71
+ # A value of 0.3 often yields good results.
72
+ # We recommend using a higher eta when increasing the number of inference steps.
73
+ eta=0.3,
74
+ generator=torch.Generator(device=device).manual_seed(0),
75
+ ).images[0]
76
+ ```
77
+ ![](./assets/t2i_tcd.png)
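+
+ The comments above touch on `eta` (the `gamma` of the paper): it controls how much stochasticity is injected at each step, and larger step counts generally benefit from a larger `eta`. Below is a minimal, optional sketch that reuses the `pipe`, `prompt`, and `device` defined above to render the same seed at a few settings for side-by-side comparison; the particular (steps, eta) pairs are illustrative rather than tuned recommendations.
+ ```py
+ from diffusers.utils import make_image_grid
+
+ # Illustrative (num_inference_steps, eta) pairs -- not official recommendations.
+ settings = [(4, 0.3), (8, 0.5), (16, 0.8)]
+ images = [
+     pipe(
+         prompt=prompt,
+         num_inference_steps=steps,
+         guidance_scale=0,
+         eta=eta,
+         generator=torch.Generator(device=device).manual_seed(0),
+     ).images[0]
+     for steps, eta in settings
+ ]
+ grid_image = make_image_grid(images, rows=1, cols=len(settings))
+ ```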
78
+
79
+ ### Inpainting
80
+ ```py
81
+ import torch
82
+ from diffusers import AutoPipelineForInpainting
83
+ from diffusers.utils import load_image, make_image_grid
84
+ from scheduling_tcd import TCDScheduler
85
+
86
+ device = "cuda"
87
+ base_model_id = "diffusers/stable-diffusion-xl-1.0-inpainting-0.1"
88
+ tcd_lora_id = "h1t/TCD-SDXL-LoRA"
89
+
90
+ pipe = AutoPipelineForInpainting.from_pretrained(base_model_id, torch_dtype=torch.float16, variant="fp16").to(device)
91
+ pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)
92
+
93
+ pipe.load_lora_weights(tcd_lora_id)
94
+ pipe.fuse_lora()
95
+
96
+ img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
97
+ mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
98
+
99
+ init_image = load_image(img_url).resize((1024, 1024))
100
+ mask_image = load_image(mask_url).resize((1024, 1024))
101
+
102
+ prompt = "a tiger sitting on a park bench"
103
+
104
+ image = pipe(
105
+ prompt=prompt,
106
+ image=init_image,
107
+ mask_image=mask_image,
108
+ num_inference_steps=8,
109
+ guidance_scale=0,
110
+ eta=0.3, # Eta (referred to as `gamma` in the paper) is used to control the stochasticity in every step. A value of 0.3 often yields good results.
111
+ strength=0.99, # make sure to use `strength` below 1.0
112
+ generator=torch.Generator(device=device).manual_seed(0),
113
+ ).images[0]
114
+
115
+ grid_image = make_image_grid([init_image, mask_image, image], rows=1, cols=3)
116
+ ```
117
+ ![](./assets/inpainting_tcd.png)
118
+
119
+ ### Versatile for Community Models
120
+ ```py
121
+ import torch
122
+ from diffusers import StableDiffusionXLPipeline
123
+ from scheduling_tcd import TCDScheduler
124
+
125
+ device = "cuda"
126
+ base_model_id = "cagliostrolab/animagine-xl-3.0"
127
+ tcd_lora_id = "h1t/TCD-SDXL-LoRA"
128
+
129
+ pipe = StableDiffusionXLPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16, variant="fp16").to(device)
130
+ pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)
131
+
132
+ pipe.load_lora_weights(tcd_lora_id)
133
+ pipe.fuse_lora()
134
+
135
+ prompt = "A man, clad in a meticulously tailored military uniform, stands with unwavering resolve. The uniform boasts intricate details, and his eyes gleam with determination. Strands of vibrant, windswept hair peek out from beneath the brim of his cap."
136
+
137
+ image = pipe(
138
+ prompt=prompt,
139
+ num_inference_steps=8,
140
+ guidance_scale=0,
141
+ # Eta (referred to as `gamma` in the paper) is used to control the stochasticity in every step.
142
+ # A value of 0.3 often yields good results.
143
+ # We recommend using a higher eta when increasing the number of inference steps.
144
+ eta=0.3,
145
+ generator=torch.Generator(device=device).manual_seed(0),
146
+ ).images[0]
147
+ ```
148
+ ![](./assets/animagine_xl.png)
149
+
150
+ ### Combine with styled LoRA
151
+ ```py
152
+ import torch
153
+ from diffusers import StableDiffusionXLPipeline
154
+ from scheduling_tcd import TCDScheduler
155
+
156
+ device = "cuda"
157
+ base_model_id = "stabilityai/stable-diffusion-xl-base-1.0"
158
+ tcd_lora_id = "h1t/TCD-SDXL-LoRA"
159
+ styled_lora_id = "TheLastBen/Papercut_SDXL"
160
+
161
+ pipe = StableDiffusionXLPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16, variant="fp16").to(device)
162
+ pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)
163
+
164
+ pipe.load_lora_weights(tcd_lora_id, adapter_name="tcd")
165
+ pipe.load_lora_weights(styled_lora_id, adapter_name="style")
166
+ pipe.set_adapters(["tcd", "style"], adapter_weights=[1.0, 1.0])
167
+
168
+ prompt = "papercut of a winter mountain, snow"
169
+
170
+ image = pipe(
171
+ prompt=prompt,
172
+ num_inference_steps=4,
173
+ guidance_scale=0,
174
+ # Eta (referred to as `gamma` in the paper) is used to control the stochasticity in every step.
175
+ # A value of 0.3 often yields good results.
176
+ # We recommend using a higher eta when increasing the number of inference steps.
177
+ eta=0.3,
178
+ generator=torch.Generator(device=device).manual_seed(0),
179
+ ).images[0]
180
+ ```
181
+ ![](./assets/styled_lora.png)
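+
+ A small optional follow-up: `set_adapters` lets you rebalance the two LoRAs if the style overpowers the subject (or vice versa). The sketch below reuses the `pipe` from the snippet above; the 0.8 weight is only an illustrative value, not a recommendation.
+ ```py
+ # Keep TCD at full strength and slightly tone down the style LoRA, then generate again.
+ pipe.set_adapters(["tcd", "style"], adapter_weights=[1.0, 0.8])
+ ```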
182
+
183
+ ### Compatibility with ControlNet
184
+ #### Depth ControlNet
185
+ ```py
186
+ import torch
187
+ import numpy as np
188
+ from PIL import Image
189
+ from transformers import DPTFeatureExtractor, DPTForDepthEstimation
190
+ from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline
191
+ from diffusers.utils import load_image, make_image_grid
192
+ from scheduling_tcd import TCDScheduler
193
+
194
+ device = "cuda"
195
+ depth_estimator = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas").to(device)
196
+ feature_extractor = DPTFeatureExtractor.from_pretrained("Intel/dpt-hybrid-midas")
197
+
198
+ def get_depth_map(image):
199
+ image = feature_extractor(images=image, return_tensors="pt").pixel_values.to(device)
200
+ with torch.no_grad(), torch.autocast(device):
201
+ depth_map = depth_estimator(image).predicted_depth
202
+
203
+ depth_map = torch.nn.functional.interpolate(
204
+ depth_map.unsqueeze(1),
205
+ size=(1024, 1024),
206
+ mode="bicubic",
207
+ align_corners=False,
208
+ )
209
+ depth_min = torch.amin(depth_map, dim=[1, 2, 3], keepdim=True)
210
+ depth_max = torch.amax(depth_map, dim=[1, 2, 3], keepdim=True)
211
+ depth_map = (depth_map - depth_min) / (depth_max - depth_min)
212
+ image = torch.cat([depth_map] * 3, dim=1)
213
+
214
+ image = image.permute(0, 2, 3, 1).cpu().numpy()[0]
215
+ image = Image.fromarray((image * 255.0).clip(0, 255).astype(np.uint8))
216
+ return image
217
+
218
+ base_model_id = "stabilityai/stable-diffusion-xl-base-1.0"
219
+ controlnet_id = "diffusers/controlnet-depth-sdxl-1.0"
220
+ tcd_lora_id = "h1t/TCD-SDXL-LoRA"
221
+
222
+ controlnet = ControlNetModel.from_pretrained(
223
+ controlnet_id,
224
+ torch_dtype=torch.float16,
225
+ variant="fp16",
226
+ ).to(device)
227
+ pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
228
+ base_model_id,
229
+ controlnet=controlnet,
230
+ torch_dtype=torch.float16,
231
+ variant="fp16",
232
+ ).to(device)
233
+ pipe.enable_model_cpu_offload()
234
+
235
+ pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)
236
+
237
+ pipe.load_lora_weights(tcd_lora_id)
238
+ pipe.fuse_lora()
239
+
240
+ prompt = "stormtrooper lecture, photorealistic"
241
+
242
+ image = load_image("https://huggingface.co/lllyasviel/sd-controlnet-depth/resolve/main/images/stormtrooper.png")
243
+ depth_image = get_depth_map(image)
244
+
245
+ controlnet_conditioning_scale = 0.5 # recommended for good generalization
246
+
247
+ image = pipe(
248
+ prompt,
249
+ image=depth_image,
250
+ num_inference_steps=4,
251
+ guidance_scale=0,
252
+ eta=0.3, # A parameter (referred to as `gamma` in the paper) is used to control the stochasticity in every step. A value of 0.3 often yields good results.
253
+ controlnet_conditioning_scale=controlnet_conditioning_scale,
254
+ generator=torch.Generator(device=device).manual_seed(0),
255
+ ).images[0]
256
+
257
+ grid_image = make_image_grid([depth_image, image], rows=1, cols=2)
258
+ ```
259
+ ![](./assets/controlnet_depth_tcd.png)
260
+
261
+ #### Canny ControlNet
262
+ ```py
263
+ import torch
264
+ from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline
265
+ from diffusers.utils import load_image, make_image_grid
266
+ from scheduling_tcd import TCDScheduler
267
+
268
+ device = "cuda"
269
+ base_model_id = "stabilityai/stable-diffusion-xl-base-1.0"
270
+ controlnet_id = "diffusers/controlnet-canny-sdxl-1.0"
271
+ tcd_lora_id = "h1t/TCD-SDXL-LoRA"
272
+
273
+ controlnet = ControlNetModel.from_pretrained(
274
+ controlnet_id,
275
+ torch_dtype=torch.float16,
276
+ variant="fp16",
277
+ ).to(device)
278
+ pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
279
+ base_model_id,
280
+ controlnet=controlnet,
281
+ torch_dtype=torch.float16,
282
+ variant="fp16",
283
+ ).to(device)
284
+ pipe.enable_model_cpu_offload()
285
+
286
+ pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)
287
+
288
+ pipe.load_lora_weights(tcd_lora_id)
289
+ pipe.fuse_lora()
290
+
291
+ prompt = "ultrarealistic shot of a furry blue bird"
292
+
293
+ canny_image = load_image("https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png")
294
+
295
+ controlnet_conditioning_scale = 0.5 # recommended for good generalization
296
+
297
+ image = pipe(
298
+ prompt,
299
+ image=canny_image,
300
+ num_inference_steps=4,
301
+ guidance_scale=0,
302
+ eta=0.3, # A parameter (referred to as `gamma` in the paper) is used to control the stochasticity in every step. A value of 0.3 often yields good results.
303
+ controlnet_conditioning_scale=controlnet_conditioning_scale,
304
+ generator=torch.Generator(device=device).manual_seed(0),
305
+ ).images[0]
306
+
307
+ grid_image = make_image_grid([canny_image, image], rows=1, cols=2)
308
+ ```
309
+
310
+ ![](./assets/controlnet_canny_tcd.png)
311
+
312
+ ### Compatibility with IP-Adapter
313
+ ⚠️ Please refer to the official [repository](https://github.com/tencent-ailab/IP-Adapter/tree/main) for instructions on installing dependencies for IP-Adapter.
314
+ ```py
315
+ import torch
316
+ from diffusers import StableDiffusionXLPipeline
317
+ from diffusers.utils import load_image, make_image_grid
318
+
319
+ from ip_adapter import IPAdapterXL
320
+ from scheduling_tcd import TCDScheduler
321
+
322
+ device = "cuda"
323
+ base_model_path = "stabilityai/stable-diffusion-xl-base-1.0"
324
+ image_encoder_path = "sdxl_models/image_encoder"
325
+ ip_ckpt = "sdxl_models/ip-adapter_sdxl.bin"
326
+ tcd_lora_id = "h1t/TCD-SDXL-LoRA"
327
+
328
+ pipe = StableDiffusionXLPipeline.from_pretrained(
329
+ base_model_path,
330
+ torch_dtype=torch.float16,
331
+ variant="fp16"
332
+ )
333
+ pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)
334
+
335
+ pipe.load_lora_weights(tcd_lora_id)
336
+ pipe.fuse_lora()
337
+
338
+ ip_model = IPAdapterXL(pipe, image_encoder_path, ip_ckpt, device)
339
+
340
+ ref_image = load_image("https://raw.githubusercontent.com/tencent-ailab/IP-Adapter/main/assets/images/woman.png").resize((512, 512))
341
+
342
+ prompt = "best quality, high quality, wearing sunglasses"
343
+
344
+ image = ip_model.generate(
345
+ pil_image=ref_image,
346
+ prompt=prompt,
347
+ scale=0.5,
348
+ num_samples=1,
349
+ num_inference_steps=4,
350
+ guidance_scale=0,
351
+ eta=0.3, # A parameter (referred to as `gamma` in the paper) is used to control the stochasticity in every step. A value of 0.3 often yields good results.
352
+ seed=0,
353
+ )[0]
354
+
355
+ grid_image = make_image_grid([ref_image, image], rows=1, cols=2)
356
+ ```
357
+ ![](./assets/ip_adapter.png)
358
+
359
+ ### Local Gradio Demo
360
+ Install the `gradio` library first,
361
+ ```bash
362
+ pip install gradio==3.50.2
363
+ ```
364
+ then the local Gradio demo can be launched with:
365
+ ```bash
366
+ python gradio_app.py
367
+ ```
368
+ ![](./assets/gradio_demo.png)
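+
+ If you want to reach the demo from another machine, Gradio can bind to all interfaces and request a temporary public link. This is an optional tweak to the last line of `gradio_app.py` (a sketch only; keep the default `demo.launch()` if local access is enough):
+ ```py
+ # Serve on the local network and request a temporary public share URL.
+ demo.launch(server_name="0.0.0.0", share=True)
+ ```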
369
+
370
+ ## Citation
371
+ ```bibtex
372
+ @article{zheng2024trajectory,
373
+ title = {Trajectory Consistency Distillation},
374
+ author = {Zheng, Jianbin and Hu, Minghui and Fan, Zhongyi and Wang, Chaoyue and Ding, Changxing and Tao, Dacheng and Cham, Tat-Jen},
375
+ journal = {arXiv},
376
+ year = {2024},
377
+ }
378
+ ```
379
+
380
+ ## Acknowledgments
381
+ This codebase heavily relies on the 🤗 [Diffusers](https://github.com/huggingface/diffusers) library and [LCM](https://github.com/luosiallen/latent-consistency-model).
gradio_app.py ADDED
@@ -0,0 +1,117 @@
1
+ import random
2
+
3
+ import gradio as gr
4
+ import torch
5
+ from diffusers import StableDiffusionXLPipeline
6
+
7
+ from scheduling_tcd import TCDScheduler
8
+
9
+ device = "cuda"
10
+ base_model_id = "stabilityai/stable-diffusion-xl-base-1.0"
11
+ tcd_lora_id = "h1t/TCD-SDXL-LoRA"
12
+
13
+ pipe = StableDiffusionXLPipeline.from_pretrained(
14
+ base_model_id,
15
+ torch_dtype=torch.float16,
16
+ variant="fp16"
17
+ ).to(device)
18
+ pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)
19
+
20
+ pipe.load_lora_weights(tcd_lora_id)
21
+ pipe.fuse_lora()
22
+
23
+
24
+ def inference(prompt, num_inference_steps=4, seed=-1, eta=0.3):
25
+ if seed is None or seed == '' or seed == -1:
26
+ seed = int(random.randrange(4294967294))
27
+ generator = torch.Generator(device=device).manual_seed(int(seed))
28
+ image = pipe(
29
+ prompt=prompt,
30
+ num_inference_steps=num_inference_steps,
31
+ guidance_scale=0,
32
+ eta=eta,
33
+ generator=generator,
34
+ ).images[0]
35
+ return image
36
+
37
+
38
+ # Define style
39
+ title = "<h1 style='text-align: center'>Trajectory Consistency Distillation</h1>"
40
+ description = "Official 🤗 Gradio demo for Trajectory Consistency Distillation"
41
+ article = "<p style='text-align: center'><a href='https://arxiv.org/abs/' target='_blank'>Trajectory Consistency Distillation</a> | <a href='https://github.com/jabir-zheng/TCD' target='_blank'>Github Repo</a></p>"
42
+
43
+
44
+ default_prompt = "Painting of the orange cat Otto von Garfield, Count of Bismarck-Schönhausen, Duke of Lauenburg, Minister-President of Prussia. Depicted wearing a Prussian Pickelhaube and eating his favorite meal - lasagna."
45
+ examples = [
46
+ [
47
+ "Beautiful woman, bubblegum pink, lemon yellow, minty blue, futuristic, high-detail, epic composition, watercolor.",
48
+ 4
49
+ ],
50
+ [
51
+ "Beautiful man, bubblegum pink, lemon yellow, minty blue, futuristic, high-detail, epic composition, watercolor.",
52
+ 8
53
+ ],
54
+ [
55
+ "Painting of the orange cat Otto von Garfield, Count of Bismarck-Schönhausen, Duke of Lauenburg, Minister-President of Prussia. Depicted wearing a Prussian Pickelhaube and eating his favorite meal - lasagna.",
56
+ 16
57
+ ],
58
+ [
59
+ "closeup portrait of 1 Persian princess, royal clothing, makeup, jewelry, wind-blown long hair, symmetric, desert, sands, dusty and foggy, sand storm, winds bokeh, depth of field, centered.",
60
+ 16
61
+ ],
62
+ ]
63
+
64
+ # The output image component is defined inside the Blocks layout below.
65
+
66
+ with gr.Blocks() as demo:
67
+ gr.Markdown(f'# {title}\n### {description}')
68
+
69
+ with gr.Row():
70
+ with gr.Column():
71
+ prompt = gr.Textbox(label='Prompt', value=default_prompt)
72
+ num_inference_steps = gr.Slider(
73
+ label='Inference steps',
74
+ minimum=4,
75
+ maximum=16,
76
+ value=4,
77
+ step=1,
78
+ )
79
+
80
+ with gr.Accordion("Advanced Options", open=False):  # collapsed by default
81
+ with gr.Row():
82
+ with gr.Column():
83
+ seed = gr.Number(label="Random Seed", value=-1)
84
+ with gr.Column():
85
+ eta = gr.Slider(
86
+ label='Gamma',
87
+ minimum=0.,
88
+ maximum=1.,
89
+ value=0.3,
90
+ step=0.1,
91
+ )
92
+
93
+ with gr.Row():
94
+ clear = gr.ClearButton(
95
+ components=[prompt, num_inference_steps, seed, eta])
96
+ submit = gr.Button(value='Submit')
97
+
98
+ examples = gr.Examples(
99
+ label="Quick Examples",
100
+ examples=examples,
101
+ inputs=[prompt, num_inference_steps],  # each example supplies a prompt and a step count
102
+ outputs=None,  # no cached outputs needed since cache_examples=False
103
+ cache_examples=False
104
+ )
105
+
106
+ with gr.Column():
107
+ outputs = gr.Image(label='Generated Images')
108
+
109
+ gr.Markdown(f'{article}')
110
+
111
+ submit.click(
112
+ fn=inference,
113
+ inputs=[prompt, num_inference_steps, seed, eta],
114
+ outputs=outputs,
115
+ )
116
+
117
+ demo.launch()
scheduling_tcd.py ADDED
@@ -0,0 +1,657 @@
1
+ # Copyright 2023 Stanford University Team and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion
16
+ # and https://github.com/hojonathanho/diffusion
17
+
18
+ import math
19
+ from dataclasses import dataclass
20
+ from typing import List, Optional, Tuple, Union
21
+
22
+ import numpy as np
23
+ import torch
24
+
25
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
26
+ from diffusers.utils import BaseOutput, logging
27
+ from diffusers.utils.torch_utils import randn_tensor
28
+ from diffusers.schedulers.scheduling_utils import SchedulerMixin
29
+
30
+
31
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
32
+
33
+
34
+ @dataclass
35
+ class TCDSchedulerOutput(BaseOutput):
36
+ """
37
+ Output class for the scheduler's `step` function output.
38
+
39
+ Args:
40
+ prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
41
+ Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
42
+ denoising loop.
43
+ pred_noised_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
44
+ The predicted noised sample `(x_{s})` based on the model output from the current timestep.
45
+ """
46
+
47
+ prev_sample: torch.FloatTensor
48
+ pred_noised_sample: Optional[torch.FloatTensor] = None
49
+
50
+
51
+ # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
52
+ def betas_for_alpha_bar(
53
+ num_diffusion_timesteps,
54
+ max_beta=0.999,
55
+ alpha_transform_type="cosine",
56
+ ):
57
+ """
58
+ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
59
+ (1-beta) over time from t = [0,1].
60
+
61
+ Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
62
+ to that part of the diffusion process.
63
+
64
+
65
+ Args:
66
+ num_diffusion_timesteps (`int`): the number of betas to produce.
67
+ max_beta (`float`): the maximum beta to use; use values lower than 1 to
68
+ prevent singularities.
69
+ alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
70
+ Choose from `cosine` or `exp`
71
+
72
+ Returns:
73
+ betas (`torch.Tensor`): the betas used by the scheduler to step the model outputs
74
+ """
75
+ if alpha_transform_type == "cosine":
76
+
77
+ def alpha_bar_fn(t):
78
+ return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
79
+
80
+ elif alpha_transform_type == "exp":
81
+
82
+ def alpha_bar_fn(t):
83
+ return math.exp(t * -12.0)
84
+
85
+ else:
86
+ raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")
87
+
88
+ betas = []
89
+ for i in range(num_diffusion_timesteps):
90
+ t1 = i / num_diffusion_timesteps
91
+ t2 = (i + 1) / num_diffusion_timesteps
92
+ betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
93
+ return torch.tensor(betas, dtype=torch.float32)
94
+
95
+
96
+ # Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
97
+ def rescale_zero_terminal_snr(betas: torch.FloatTensor) -> torch.FloatTensor:
98
+ """
99
+ Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)
100
+
101
+
102
+ Args:
103
+ betas (`torch.FloatTensor`):
104
+ the betas that the scheduler is being initialized with.
105
+
106
+ Returns:
107
+ `torch.FloatTensor`: rescaled betas with zero terminal SNR
108
+ """
109
+ # Convert betas to alphas_bar_sqrt
110
+ alphas = 1.0 - betas
111
+ alphas_cumprod = torch.cumprod(alphas, dim=0)
112
+ alphas_bar_sqrt = alphas_cumprod.sqrt()
113
+
114
+ # Store old values.
115
+ alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
116
+ alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
117
+
118
+ # Shift so the last timestep is zero.
119
+ alphas_bar_sqrt -= alphas_bar_sqrt_T
120
+
121
+ # Scale so the first timestep is back to the old value.
122
+ alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
123
+
124
+ # Convert alphas_bar_sqrt to betas
125
+ alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
126
+ alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod
127
+ alphas = torch.cat([alphas_bar[0:1], alphas])
128
+ betas = 1 - alphas
129
+
130
+ return betas
131
+
132
+
133
+ class TCDScheduler(SchedulerMixin, ConfigMixin):
134
+ """
135
+ `TCDScheduler` incorporates the `Strategic Stochastic Sampling` introduced by the paper `Trajectory Consistency Distillation`,
136
+ extending the original Multistep Consistency Sampling to enable unrestricted trajectory traversal.
137
+
138
+ This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. [`~ConfigMixin`] takes care of storing all config
139
+ attributes that are passed in the scheduler's `__init__` function, such as `num_train_timesteps`. They can be
140
+ accessed via `scheduler.config.num_train_timesteps`. [`SchedulerMixin`] provides general loading and saving
141
+ functionality via the [`SchedulerMixin.save_pretrained`] and [`~SchedulerMixin.from_pretrained`] functions.
142
+
143
+ Args:
144
+ num_train_timesteps (`int`, defaults to 1000):
145
+ The number of diffusion steps to train the model.
146
+ beta_start (`float`, defaults to 0.00085):
147
+ The starting `beta` value of inference.
148
+ beta_end (`float`, defaults to 0.012):
149
+ The final `beta` value.
150
+ beta_schedule (`str`, defaults to `"scaled_linear"`):
151
+ The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
152
+ `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
153
+ trained_betas (`np.ndarray`, *optional*):
154
+ Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
155
+ original_inference_steps (`int`, *optional*, defaults to 50):
156
+ The default number of inference steps used to generate a linearly-spaced timestep schedule, from which we
157
+ will ultimately take `num_inference_steps` evenly spaced timesteps to form the final timestep schedule.
158
+ clip_sample (`bool`, defaults to `False`):
159
+ Clip the predicted sample for numerical stability.
160
+ clip_sample_range (`float`, defaults to 1.0):
161
+ The maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
162
+ set_alpha_to_one (`bool`, defaults to `True`):
163
+ Each diffusion step uses the alphas product value at that step and at the previous one. For the final step
164
+ there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
165
+ otherwise it uses the alpha value at step 0.
166
+ steps_offset (`int`, defaults to 0):
167
+ An offset added to the inference steps. You can use a combination of `offset=1` and
168
+ `set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product like in Stable
169
+ Diffusion.
170
+ prediction_type (`str`, defaults to `epsilon`, *optional*):
171
+ Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
172
+ `sample` (directly predicts the noisy sample) or `v_prediction` (see section 2.4 of [Imagen
173
+ Video](https://imagen.research.google/video/paper.pdf) paper).
174
+ thresholding (`bool`, defaults to `False`):
175
+ Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
176
+ as Stable Diffusion.
177
+ dynamic_thresholding_ratio (`float`, defaults to 0.995):
178
+ The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
179
+ sample_max_value (`float`, defaults to 1.0):
180
+ The threshold value for dynamic thresholding. Valid only when `thresholding=True`.
181
+ timestep_spacing (`str`, defaults to `"leading"`):
182
+ The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
183
+ Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
184
+ timestep_scaling (`float`, defaults to 10.0):
185
+ The factor the timesteps will be multiplied by when calculating the consistency model boundary conditions
186
+ `c_skip` and `c_out`. Increasing this will decrease the approximation error (although the approximation
187
+ error at the default of `10.0` is already pretty small).
188
+ rescale_betas_zero_snr (`bool`, defaults to `False`):
189
+ Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
190
+ dark samples instead of limiting it to samples with medium brightness. Loosely related to
191
+ [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
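+
+ Example (a minimal usage sketch mirroring the README; the model ids and parameter values below are illustrative):
+
+ ```py
+ import torch
+ from diffusers import StableDiffusionXLPipeline
+ from scheduling_tcd import TCDScheduler
+
+ pipe = StableDiffusionXLPipeline.from_pretrained(
+     "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, variant="fp16"
+ ).to("cuda")
+ pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)
+ pipe.load_lora_weights("h1t/TCD-SDXL-LoRA")
+ pipe.fuse_lora()
+ image = pipe("a photo of a cat", num_inference_steps=4, guidance_scale=0, eta=0.3).images[0]
+ ```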
192
+ """
193
+
194
+ order = 1
195
+
196
+ @register_to_config
197
+ def __init__(
198
+ self,
199
+ num_train_timesteps: int = 1000,
200
+ beta_start: float = 0.00085,
201
+ beta_end: float = 0.012,
202
+ beta_schedule: str = "scaled_linear",
203
+ trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
204
+ original_inference_steps: int = 50,
205
+ clip_sample: bool = False,
206
+ clip_sample_range: float = 1.0,
207
+ set_alpha_to_one: bool = True,
208
+ steps_offset: int = 0,
209
+ prediction_type: str = "epsilon",
210
+ thresholding: bool = False,
211
+ dynamic_thresholding_ratio: float = 0.995,
212
+ sample_max_value: float = 1.0,
213
+ timestep_spacing: str = "leading",
214
+ timestep_scaling: float = 10.0,
215
+ rescale_betas_zero_snr: bool = False,
216
+ ):
217
+ if trained_betas is not None:
218
+ self.betas = torch.tensor(trained_betas, dtype=torch.float32)
219
+ elif beta_schedule == "linear":
220
+ self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
221
+ elif beta_schedule == "scaled_linear":
222
+ # this schedule is very specific to the latent diffusion model.
223
+ self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
224
+ elif beta_schedule == "squaredcos_cap_v2":
225
+ # Glide cosine schedule
226
+ self.betas = betas_for_alpha_bar(num_train_timesteps)
227
+ else:
228
+ raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
229
+
230
+ # Rescale for zero SNR
231
+ if rescale_betas_zero_snr:
232
+ self.betas = rescale_zero_terminal_snr(self.betas)
233
+
234
+ self.alphas = 1.0 - self.betas
235
+ self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
236
+
237
+ # At every step in ddim, we are looking into the previous alphas_cumprod
238
+ # For the final step, there is no previous alphas_cumprod because we are already at 0
239
+ # `set_alpha_to_one` decides whether we set this parameter simply to one or
240
+ # whether we use the final alpha of the "non-previous" one.
241
+ self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
242
+
243
+ # standard deviation of the initial noise distribution
244
+ self.init_noise_sigma = 1.0
245
+
246
+ # setable values
247
+ self.num_inference_steps = None
248
+ self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64))
249
+ self.custom_timesteps = False
250
+
251
+ self._step_index = None
252
+
253
+ # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
254
+ def _init_step_index(self, timestep):
255
+ if isinstance(timestep, torch.Tensor):
256
+ timestep = timestep.to(self.timesteps.device)
257
+
258
+ index_candidates = (self.timesteps == timestep).nonzero()
259
+
260
+ # The sigma index that is taken for the **very** first `step`
261
+ # is always the second index (or the last index if there is only 1)
262
+ # This way we can ensure we don't accidentally skip a sigma in
263
+ # case we start in the middle of the denoising schedule (e.g. for image-to-image)
264
+ if len(index_candidates) > 1:
265
+ step_index = index_candidates[1]
266
+ else:
267
+ step_index = index_candidates[0]
268
+
269
+ self._step_index = step_index.item()
270
+
271
+ @property
272
+ def step_index(self):
273
+ return self._step_index
274
+
275
+ def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
276
+ """
277
+ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
278
+ current timestep.
279
+
280
+ Args:
281
+ sample (`torch.FloatTensor`):
282
+ The input sample.
283
+ timestep (`int`, *optional*):
284
+ The current timestep in the diffusion chain.
285
+ Returns:
286
+ `torch.FloatTensor`:
287
+ A scaled input sample.
288
+ """
289
+ return sample
290
+
291
+ def _get_variance(self, timestep, prev_timestep):
292
+ alpha_prod_t = self.alphas_cumprod[timestep]
293
+ alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
294
+ beta_prod_t = 1 - alpha_prod_t
295
+ beta_prod_t_prev = 1 - alpha_prod_t_prev
296
+
297
+ variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
298
+
299
+ return variance
300
+
301
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
302
+ def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
303
+ """
304
+ "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
305
+ prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
306
+ s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
307
+ pixels from saturation at each step. We find that dynamic thresholding results in significantly better
308
+ photorealism as well as better image-text alignment, especially when using very large guidance weights."
309
+
310
+ https://arxiv.org/abs/2205.11487
311
+ """
312
+ dtype = sample.dtype
313
+ batch_size, channels, *remaining_dims = sample.shape
314
+
315
+ if dtype not in (torch.float32, torch.float64):
316
+ sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half
317
+
318
+ # Flatten sample for doing quantile calculation along each image
319
+ sample = sample.reshape(batch_size, channels * np.prod(remaining_dims))
320
+
321
+ abs_sample = sample.abs() # "a certain percentile absolute pixel value"
322
+
323
+ s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
324
+ s = torch.clamp(
325
+ s, min=1, max=self.config.sample_max_value
326
+ ) # When clamped to min=1, equivalent to standard clipping to [-1, 1]
327
+ s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0
328
+ sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s"
329
+
330
+ sample = sample.reshape(batch_size, channels, *remaining_dims)
331
+ sample = sample.to(dtype)
332
+
333
+ return sample
334
+
335
+ def set_timesteps(
336
+ self,
337
+ num_inference_steps: Optional[int] = None,
338
+ device: Union[str, torch.device] = None,
339
+ original_inference_steps: Optional[int] = None,
340
+ timesteps: Optional[List[int]] = None,
341
+ strength: float = 1.0,
342
+ ):
343
+ """
344
+ Sets the discrete timesteps used for the diffusion chain (to be run before inference).
345
+
346
+ Args:
347
+ num_inference_steps (`int`, *optional*):
348
+ The number of diffusion steps used when generating samples with a pre-trained model. If used,
349
+ `timesteps` must be `None`.
350
+ device (`str` or `torch.device`, *optional*):
351
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
352
+ original_inference_steps (`int`, *optional*):
353
+ The original number of inference steps, which will be used to generate a linearly-spaced timestep
354
+ schedule (which is different from the standard `diffusers` implementation). We will then take
355
+ `num_inference_steps` timesteps from this schedule, evenly spaced in terms of indices, and use that as
356
+ our final timestep schedule. If not set, this will default to the `original_inference_steps` attribute.
357
+ timesteps (`List[int]`, *optional*):
358
+ Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
359
+ timestep spacing strategy of equal spacing between timesteps on the training/distillation timestep
360
+ schedule is used. If `timesteps` is passed, `num_inference_steps` must be `None`.
361
+ """
362
+ # 0. Check inputs
363
+ if num_inference_steps is None and timesteps is None:
364
+ raise ValueError("Must pass exactly one of `num_inference_steps` or `timesteps`.")
365
+
366
+ if num_inference_steps is not None and timesteps is not None:
367
+ raise ValueError("Can only pass one of `num_inference_steps` or `timesteps`.")
368
+
369
+ # 1. Calculate the TCD original training/distillation timestep schedule.
370
+ original_steps = (
371
+ original_inference_steps if original_inference_steps is not None else self.config.original_inference_steps
372
+ )
373
+
374
+ if original_steps is not None:
375
+ if original_steps > self.config.num_train_timesteps:
376
+ raise ValueError(
377
+ f"`original_steps`: {original_steps} cannot be larger than `self.config.train_timesteps`:"
378
+ f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
379
+ f" maximal {self.config.num_train_timesteps} timesteps."
380
+ )
381
+ # TCD Timesteps Setting
382
+ # The skipping step parameter k from the paper.
383
+ k = self.config.num_train_timesteps // original_steps
384
+ # TCD Training/Distillation Steps Schedule
385
+ tcd_origin_timesteps = np.asarray(list(range(1, int(original_steps * strength) + 1))) * k - 1
386
+ else:
387
+ tcd_origin_timesteps = np.asarray(list(range(0, int(self.config.num_train_timesteps * strength))))
388
+
389
+ # 2. Calculate the TCD inference timestep schedule.
390
+ if timesteps is not None:
391
+ # 2.1 Handle custom timestep schedules.
392
+ train_timesteps = set(tcd_origin_timesteps)
393
+ non_train_timesteps = []
394
+ for i in range(1, len(timesteps)):
395
+ if timesteps[i] >= timesteps[i - 1]:
396
+ raise ValueError("`custom_timesteps` must be in descending order.")
397
+
398
+ if timesteps[i] not in train_timesteps:
399
+ non_train_timesteps.append(timesteps[i])
400
+
401
+ if timesteps[0] >= self.config.num_train_timesteps:
402
+ raise ValueError(
403
+ f"`timesteps` must start before `self.config.train_timesteps`:"
404
+ f" {self.config.num_train_timesteps}."
405
+ )
406
+
407
+ # Raise warning if timestep schedule does not start with self.config.num_train_timesteps - 1
408
+ if strength == 1.0 and timesteps[0] != self.config.num_train_timesteps - 1:
409
+ logger.warning(
410
+ f"The first timestep on the custom timestep schedule is {timesteps[0]}, not"
411
+ f" `self.config.num_train_timesteps - 1`: {self.config.num_train_timesteps - 1}. You may get"
412
+ f" unexpected results when using this timestep schedule."
413
+ )
414
+
415
+ # Raise warning if custom timestep schedule contains timesteps not on original timestep schedule
416
+ if non_train_timesteps:
417
+ logger.warning(
418
+ f"The custom timestep schedule contains the following timesteps which are not on the original"
419
+ f" training/distillation timestep schedule: {non_train_timesteps}. You may get unexpected results"
420
+ f" when using this timestep schedule."
421
+ )
422
+
423
+ # Raise warning if custom timestep schedule is longer than original_steps
424
+ if original_steps is not None:
425
+ if len(timesteps) > original_steps:
426
+ logger.warning(
427
+ f"The number of timesteps in the custom timestep schedule is {len(timesteps)}, which exceeds the"
428
+ f" length of the timestep schedule used for training: {original_steps}. You may get some"
429
+ f" unexpected results when using this timestep schedule."
430
+ )
431
+ else:
432
+ if len(timesteps) > self.config.num_train_timesteps:
433
+ logger.warning(
434
+ f"The number of timesteps in the custom timestep schedule is {len(timesteps)}, which exceeds the"
435
+ f" length of the timestep schedule used for training: {self.config.num_train_timesteps}. You may get some"
436
+ f" unexpected results when using this timestep schedule."
437
+ )
438
+
439
+ timesteps = np.array(timesteps, dtype=np.int64)
440
+ self.num_inference_steps = len(timesteps)
441
+ self.custom_timesteps = True
442
+
443
+ # Apply strength (e.g. for img2img pipelines) (see StableDiffusionImg2ImgPipeline.get_timesteps)
444
+ init_timestep = min(int(self.num_inference_steps * strength), self.num_inference_steps)
445
+ t_start = max(self.num_inference_steps - init_timestep, 0)
446
+ timesteps = timesteps[t_start * self.order :]
447
+ # TODO: also reset self.num_inference_steps?
448
+ else:
449
+ # 2.2 Create the "standard" TCD inference timestep schedule.
450
+ if num_inference_steps > self.config.num_train_timesteps:
451
+ raise ValueError(
452
+ f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:"
453
+ f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
454
+ f" maximal {self.config.num_train_timesteps} timesteps."
455
+ )
456
+
457
+ if original_steps is not None:
458
+ skipping_step = len(tcd_origin_timesteps) // num_inference_steps
459
+
460
+ if skipping_step < 1:
461
+ raise ValueError(
462
+ f"The combination of `original_steps x strength`: {original_steps} x {strength} is smaller than `num_inference_steps`: {num_inference_steps}. Make sure to either reduce `num_inference_steps` to a value smaller than {int(original_steps * strength)} or increase `strength` to a value higher than {float(num_inference_steps / original_steps)}."
463
+ )
464
+
465
+ self.num_inference_steps = num_inference_steps
466
+
467
+ if original_steps is not None:
468
+ if num_inference_steps > original_steps:
469
+ raise ValueError(
470
+ f"`num_inference_steps`: {num_inference_steps} cannot be larger than `original_inference_steps`:"
471
+ f" {original_steps} because the final timestep schedule will be a subset of the"
472
+ f" `original_inference_steps`-sized initial timestep schedule."
473
+ )
474
+ else:
475
+ if num_inference_steps > self.config.num_train_timesteps:
476
+ raise ValueError(
477
+ f"`num_inference_steps`: {num_inference_steps} cannot be larger than `num_train_timesteps`:"
478
+ f" {self.config.num_train_timesteps} because the final timestep schedule will be a subset of the"
479
+ f" `num_train_timesteps`-sized initial timestep schedule."
480
+ )
481
+
482
+ # TCD Inference Steps Schedule
483
+ tcd_origin_timesteps = tcd_origin_timesteps[::-1].copy()
484
+ # Select (approximately) evenly spaced indices from tcd_origin_timesteps.
485
+ inference_indices = np.linspace(0, len(tcd_origin_timesteps), num=num_inference_steps, endpoint=False)
486
+ inference_indices = np.floor(inference_indices).astype(np.int64)
487
+ timesteps = tcd_origin_timesteps[inference_indices]
488
+
489
+ self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.long)
490
+
491
+ self._step_index = None
492
+
493
+ def step(
494
+ self,
495
+ model_output: torch.FloatTensor,
496
+ timestep: int,
497
+ sample: torch.FloatTensor,
498
+ eta: float,
499
+ generator: Optional[torch.Generator] = None,
500
+ return_dict: bool = True,
501
+ ) -> Union[TCDSchedulerOutput, Tuple]:
502
+ """
503
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
504
+ process from the learned model outputs (most often the predicted noise).
505
+
506
+ Args:
507
+ model_output (`torch.FloatTensor`):
508
+ The direct output from learned diffusion model.
509
+ timestep (`int`):
510
+ The current discrete timestep in the diffusion chain.
511
+ sample (`torch.FloatTensor`):
512
+ A current instance of a sample created by the diffusion process.
513
+ eta (`float`):
514
+ A stochastic parameter (referred to as `gamma` in the paper) used to control the stochasticity in every step.
515
+ When eta = 0, it represents deterministic sampling, whereas eta = 1 indicates full stochastic sampling.
516
+ generator (`torch.Generator`, *optional*):
517
+ A random number generator.
518
+ return_dict (`bool`, *optional*, defaults to `True`):
519
+ Whether or not to return a [`~schedulers.scheduling_tcd.TCDSchedulerOutput`] or `tuple`.
520
+ Returns:
521
+ [`~schedulers.scheduling_utils.TCDSchedulerOutput`] or `tuple`:
522
+ If return_dict is `True`, [`~schedulers.scheduling_tcd.TCDSchedulerOutput`] is returned, otherwise a
523
+ tuple is returned where the first element is the sample tensor.
524
+ """
525
+ if self.num_inference_steps is None:
526
+ raise ValueError(
527
+ "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
528
+ )
529
+
530
+ if self.step_index is None:
531
+ self._init_step_index(timestep)
532
+
533
+ # 1. get previous step value
534
+ prev_step_index = self.step_index + 1
535
+ if prev_step_index < len(self.timesteps):
536
+ prev_timestep = self.timesteps[prev_step_index]
537
+ else:
538
+ prev_timestep = torch.tensor(0)
539
+
540
+ timestep_s = torch.floor((1 - eta) * prev_timestep).to(dtype=torch.long)
541
+
542
+ # 2. compute alphas, betas
543
+ alpha_prod_t = self.alphas_cumprod[timestep]
544
+ beta_prod_t = 1 - alpha_prod_t
545
+
546
+ alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
547
+ beta_prod_t_prev = 1 - alpha_prod_t_prev
548
+
549
+ alpha_prod_s = self.alphas_cumprod[timestep_s] if timestep_s >= 0 else self.final_alpha_cumprod
550
+ beta_prod_s = 1 - alpha_prod_s
551
+
552
+ # 3. Compute the predicted noised sample x_s based on the model parameterization
553
+ if self.config.prediction_type == "epsilon": # noise-prediction
554
+ pred_original_sample = (sample - beta_prod_t.sqrt() * model_output) / alpha_prod_t.sqrt()
555
+ pred_epsilon = model_output
556
+ pred_noised_sample = alpha_prod_s.sqrt() * pred_original_sample + beta_prod_s.sqrt() * pred_epsilon
557
+ elif self.config.prediction_type == "sample": # x-prediction
558
+ pred_original_sample = model_output
559
+ pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
560
+ pred_noised_sample = alpha_prod_s.sqrt() * pred_original_sample + beta_prod_s.sqrt() * pred_epsilon
561
+ elif self.config.prediction_type == "v_prediction": # v-prediction
562
+ pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
563
+ pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
564
+ pred_noised_sample = alpha_prod_s.sqrt() * pred_original_sample + beta_prod_s.sqrt() * pred_epsilon
565
+ else:
566
+ raise ValueError(
567
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` or"
568
+ " `v_prediction` for `TCDScheduler`."
569
+ )
570
+
571
+ # 4. Sample and inject noise z ~ N(0, I) for MultiStep Inference
572
+ # Noise is not used on the final timestep of the timestep schedule.
573
+ # This also means that noise is not used for one-step sampling.
574
+ # Eta (referred to as "gamma" in the paper) was introduced to control the stochasticity in every step.
575
+ # When eta = 0, it represents deterministic sampling, whereas eta = 1 indicates full stochastic sampling.
576
+ if eta > 0:
577
+ if self.step_index != self.num_inference_steps - 1:
578
+ noise = randn_tensor(
579
+ model_output.shape, generator=generator, device=model_output.device, dtype=pred_noised_sample.dtype
580
+ )
581
+ prev_sample = (alpha_prod_t_prev / alpha_prod_s).sqrt() * pred_noised_sample + (1 - alpha_prod_t_prev / alpha_prod_s).sqrt() * noise
582
+ else:
583
+ prev_sample = pred_noised_sample
584
+ else:
585
+ prev_sample = pred_noised_sample
586
+
587
+ # upon completion increase step index by one
588
+ self._step_index += 1
589
+
590
+ if not return_dict:
591
+ return (prev_sample, pred_noised_sample)
592
+
593
+ return TCDSchedulerOutput(prev_sample=prev_sample, pred_noised_sample=pred_noised_sample)
594
+
595
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
596
+ def add_noise(
597
+ self,
598
+ original_samples: torch.FloatTensor,
599
+ noise: torch.FloatTensor,
600
+ timesteps: torch.IntTensor,
601
+ ) -> torch.FloatTensor:
602
+ # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
603
+ alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
604
+ timesteps = timesteps.to(original_samples.device)
605
+
606
+ sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
607
+ sqrt_alpha_prod = sqrt_alpha_prod.flatten()
608
+ while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
609
+ sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
610
+
611
+ sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
612
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
613
+ while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
614
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
615
+
616
+ noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
617
+ return noisy_samples
618
+
619
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity
620
+ def get_velocity(
621
+ self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor
622
+ ) -> torch.FloatTensor:
623
+ # Make sure alphas_cumprod and timestep have same device and dtype as sample
624
+ alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype)
625
+ timesteps = timesteps.to(sample.device)
626
+
627
+ sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
628
+ sqrt_alpha_prod = sqrt_alpha_prod.flatten()
629
+ while len(sqrt_alpha_prod.shape) < len(sample.shape):
630
+ sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
631
+
632
+ sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
633
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
634
+ while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape):
635
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
636
+
637
+ velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
638
+ return velocity
639
+
640
+ def __len__(self):
641
+ return self.config.num_train_timesteps
642
+
643
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.previous_timestep
644
+ def previous_timestep(self, timestep):
645
+ if self.custom_timesteps:
646
+ index = (self.timesteps == timestep).nonzero(as_tuple=True)[0][0]
647
+ if index == self.timesteps.shape[0] - 1:
648
+ prev_t = torch.tensor(-1)
649
+ else:
650
+ prev_t = self.timesteps[index + 1]
651
+ else:
652
+ num_inference_steps = (
653
+ self.num_inference_steps if self.num_inference_steps else self.config.num_train_timesteps
654
+ )
655
+ prev_t = timestep - self.config.num_train_timesteps // num_inference_steps
656
+
657
+ return prev_t