xco2 committed
Commit
ebfe12f
1 Parent(s): 98545cc
app.py ADDED
@@ -0,0 +1,546 @@
import random

import gradio as gr
import time, os
import numpy as np
import torch
from tqdm import tqdm, trange
from PIL import Image


def random_clip(x, min=-1.5, max=1.5):
    if isinstance(x, np.ndarray):
        return np.clip(x, min, max)
    elif isinstance(x, torch.Tensor):
        return torch.clip(x, min, max)
    else:
        raise TypeError(f"type of x is {type(x)}")


class Sampler:
    def __init__(self, device, normal_t):
        self.device = device
        self.total_step = 1000
        self.normal_t = normal_t

        self.afas_cumprod, self.betas = self.get_afa_bars("scaled_linear",  # cosine, linear, scaled_linear
                                                          self.total_step)
        self.afas_cumprod = torch.Tensor(self.afas_cumprod).to(self.device)
        self.betas = torch.Tensor(self.betas).to(self.device)

    def betas_for_alpha_bar(self, num_diffusion_timesteps, alpha_bar, max_beta=0.999):
        """
        Create a beta schedule that discretizes the given alpha_t_bar function,
        which defines the cumulative product of (1-beta) over time from t = [0,1].

        :param num_diffusion_timesteps: the number of betas to produce.
        :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
                          produces the cumulative product of (1-beta) up to that
                          part of the diffusion process.
        :param max_beta: the maximum beta to use; use values lower than 1 to
                         prevent singularities.
        """
        betas = []
        for i in range(num_diffusion_timesteps):
            t1 = i / num_diffusion_timesteps
            t2 = (i + 1) / num_diffusion_timesteps
            betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
        return np.array(betas)

    def get_named_beta_schedule(self, schedule_name, num_diffusion_timesteps):
        """
        Get a pre-defined beta schedule for the given name.

        The beta schedule library consists of beta schedules which remain similar
        in the limit of num_diffusion_timesteps.
        Beta schedules may be added, but should not be removed or changed once
        they are committed to maintain backwards compatibility.
        """
        if schedule_name == "linear":
            # Linear schedule from Ho et al., extended to work for any number of
            # diffusion steps.
            scale = 1000 / num_diffusion_timesteps
            beta_start = scale * 0.0001
            beta_end = scale * 0.02
            return np.linspace(
                beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64
            )
        elif schedule_name == "scaled_linear":
            scale = 1000 / num_diffusion_timesteps
            beta_start = scale * 0.0001
            beta_end = scale * 0.02
            return np.linspace(
                beta_start ** 0.5, beta_end ** 0.5, num_diffusion_timesteps, dtype=np.float64) ** 2
        elif schedule_name == "cosine":
            return self.betas_for_alpha_bar(
                num_diffusion_timesteps,
                lambda t: np.cos((t + 0.008) / 1.008 * np.pi / 2) ** 2,
            )
        else:
            raise NotImplementedError(f"unknown beta schedule: {schedule_name}")

    def get_afa_bars(self, beta_schedule_name, total_step):
        """
        Build the list of cumulative alpha-bar values; the list has length total_step.
        :param beta_schedule_name: name of the beta schedule
        :return: afa_bars and betas
        """
        betas = self.get_named_beta_schedule(schedule_name=beta_schedule_name,
                                             num_diffusion_timesteps=total_step)

        afas = 1 - betas
        afas_cumprod = np.cumprod(afas)
        return afas_cumprod, betas

    # Sampling that starts from pure noise
    @torch.no_grad()
    def sample_loop(self, model, vae_middle_c, batch_size, step, eta, shape=(32, 32)):
        pass

    def apple_noise(self, data, step):
        """
        Add noise to a latent and return x_t.
        :param data: latent-space data
        :param step: the chosen diffusion step
        :return: the noised latent x_t
        """
        data = data.to(self.device)

        noise = torch.randn(size=data.shape).to(self.device)
        afa_bar_t = self.afas_cumprod[step - 1]
        x_t = torch.sqrt(afa_bar_t) * data + torch.sqrt(1 - afa_bar_t) * noise
        return x_t

    # Image-to-image sampling
    @torch.no_grad()
    def sample_loop_img2img(self, input_img, model, vae_middle_c, batch_size, step, eta):
        pass

    @torch.no_grad()
    def decode_img(self, vae, x0):
        x0 = vae.decoder(x0)
        res = x0.cpu().numpy()
        if vae.middle_c == 8:
            res = (res + 1) * 127.5
        else:
            res = res * 255
        res = np.transpose(res, [0, 2, 3, 1])  # RGB
        res = np.clip(res, 0, 255)
        res = np.array(res, dtype=np.uint8)
        return res

    @torch.no_grad()
    def encode_img(self, vae, x0):
        mu, _ = vae.encoder(x0)
        return mu

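# Illustrative sketch (not called by the app): apple_noise implements the closed-form
# DDPM forward marginal x_t = sqrt(a_bar_t) * x_0 + sqrt(1 - a_bar_t) * eps, so a
# unit-variance latent stays roughly unit variance at every step. The CPU device and
# (1, 4, 32, 32) shape below are arbitrary assumptions for the demo.
def _demo_forward_noising():
    sampler = Sampler(device="cpu", normal_t=True)
    x0 = torch.randn(1, 4, 32, 32)
    for step in (1, 250, 500, 1000):
        x_t = sampler.apple_noise(x0, step)
        print(step, float(sampler.afas_cumprod[step - 1]), float(x_t.std()))
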
class DDIMSampler(Sampler):
    def __init__(self, device, normal_t):
        super(DDIMSampler, self).__init__(device, normal_t)

    @torch.no_grad()
    def sample(self, model, x, t, next_t, eta):
        """
        One DDIM update from step t to step next_t.
        :param model: noise-prediction model
        :param x: current latent x_t
        :param t: current step, in [1, 1000]
        :return: the latent at next_t
        """
        t_ = torch.ones((x.shape[0], 1)) * t
        t_ = t_.to(self.device)
        if self.normal_t:
            t_ = t_ / self.total_step
        epsilon = model(x, t_)
        # Convert the step numbers into array indices
        t = int(t - 1)
        next_t = int(next_t - 1)
        if t > 1:
            # pred_x0 = (x - sqrt(1 - afa_t_bar) * eps) / sqrt(afa_t_bar)
            prede_x0 = (x - torch.sqrt(1 - self.afas_cumprod[t]) * epsilon) / torch.sqrt(self.afas_cumprod[t])
            x_t_1 = torch.sqrt(self.afas_cumprod[next_t]) * prede_x0
            delta = eta * torch.sqrt((1 - self.afas_cumprod[next_t]) / (1 - self.afas_cumprod[t])) * torch.sqrt(
                1 - self.afas_cumprod[t] / self.afas_cumprod[next_t])
            x_t_1 = x_t_1 + torch.sqrt(1 - self.afas_cumprod[next_t] - delta ** 2) * epsilon
            x_t_1 = delta * random_clip(torch.randn_like(x)) + x_t_1
        else:
            coeff = self.betas[t] / (torch.sqrt(1 - self.afas_cumprod[t]))
            x_t_1 = (1 / torch.sqrt(1 - self.betas[t])) * (x - coeff * epsilon)

        return x_t_1

    @torch.no_grad()
    def sample_loop(self, model, vae_middle_c, batch_size, step, eta, shape=(32, 32)):
        if step < 1000 and False:
            # Take the sub-sequence in two uneven segments:
            # the first 35% of the 1k steps gets 50% of the requested inference steps
            big_steps = self.total_step * (1 - 0.4)
            big_ = int(step * 0.6)
            steps = np.linspace(self.total_step, big_steps, big_)
            steps = np.concatenate([steps, np.linspace(big_steps + int(steps[1] - steps[0]), 1, step - big_)],
                                   axis=0)
        else:
            # Take a uniformly spaced sub-sequence of the timesteps
            steps = np.linspace(self.total_step, 1, step)
        steps = np.floor(steps)
        steps = np.concatenate((steps, steps[-1:]), axis=0)

        x_t = random_clip(torch.randn((batch_size, vae_middle_c, *shape))).to(self.device)  # 32, 32
        for i in range(len(steps) - 1):
            x_t = self.sample(model, x_t, steps[i], steps[i + 1], eta)

            yield x_t

    @torch.no_grad()
    def sample_loop_img2img(self, input_img_latents, noise_steps, model, vae_middle_c, batch_size, step, eta):
        noised_latents = self.apple_noise(input_img_latents, noise_steps)  # (1, 4, 32, 32)
        step = min(noise_steps, step)
        if step < 1000 and False:
            # Take the sub-sequence in two uneven segments:
            # the first 20% of the 1k steps gets 50% of the requested inference steps
            big_steps = noise_steps * (1 - 0.3)
            big_ = int(step * 0.5)
            steps = np.linspace(noise_steps, big_steps, big_)
            steps = np.concatenate([steps, np.linspace(big_steps + int(steps[1] - steps[0]), 1, step - big_)],
                                   axis=0)
        else:
            # Take a uniformly spaced sub-sequence of the timesteps
            steps = np.linspace(noise_steps, 1, step)
        steps = np.floor(steps)
        steps = np.concatenate((steps, steps[-1:]), axis=0)

        x_t = torch.tile(noised_latents, (batch_size, 1, 1, 1)).to(self.device)  # 32, 32
        for i in trange(len(steps) - 1):
            x_t = self.sample(model, x_t, steps[i], steps[i + 1], eta)

            yield x_t

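# Illustrative sketch (not called by the app): sample_loop walks a uniformly spaced
# sub-sequence of the 1000 training steps and repeats the last step once, so asking
# for `step` inference steps produces exactly `step` (t, next_t) pairs and yields.
def _demo_ddim_step_schedule(step=10, total_step=1000):
    steps = np.floor(np.linspace(total_step, 1, step))
    steps = np.concatenate((steps, steps[-1:]), axis=0)
    print(len(steps) - 1, steps)  # 10 update pairs for 10 requested inference steps
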
class EulerDpmppSampler(Sampler):
    def __init__(self, device, normal_t):
        super(EulerDpmppSampler, self).__init__(device, normal_t)
        self.sample_fun = self.sample_dpmpp_2m

    @staticmethod
    def append_zero(x):
        return torch.cat([x, x.new_zeros([1])])

    # 4e-5 0.99
    @staticmethod
    def get_sigmas_karras(n, sigma_min, sigma_max, rho=7., device='cuda'):
        """Constructs the noise schedule of Karras et al. (2022)."""
        ramp = torch.linspace(0, 1, n)
        min_inv_rho = sigma_min ** (1 / rho)
        max_inv_rho = sigma_max ** (1 / rho)
        sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
        return EulerDpmppSampler.append_zero(sigmas).to(device)

    @staticmethod
    def default_noise_sampler(x):
        return lambda sigma, sigma_next: torch.randn_like(x)

    @staticmethod
    def get_ancestral_step(sigma_from, sigma_to, eta=1.):
        """Calculates the noise level (sigma_down) to step down to and the amount
        of noise to add (sigma_up) when doing an ancestral sampling step."""
        if not eta:
            return sigma_to, 0.
        sigma_up = min(sigma_to, eta * (sigma_to ** 2 * (sigma_from ** 2 - sigma_to ** 2) / sigma_from ** 2) ** 0.5)
        sigma_down = (sigma_to ** 2 - sigma_up ** 2) ** 0.5
        return sigma_down, sigma_up

    @staticmethod
    def append_dims(x, target_dims):
        """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
        dims_to_append = target_dims - x.ndim
        if dims_to_append < 0:
            raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
        return x[(...,) + (None,) * dims_to_append]

    @staticmethod
    def to_d(x, sigma, denoised):
        """Converts a denoiser output to a Karras ODE derivative."""
        return (x - denoised) / EulerDpmppSampler.append_dims(sigma, x.ndim)

    @staticmethod
    def to_denoised(x, sigma, d):
        return x - d * EulerDpmppSampler.append_dims(sigma, x.ndim)

    @torch.no_grad()
    def sample_euler_ancestral(self, model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1.,
                               noise_sampler=None):
        """Ancestral sampling with Euler method steps."""
        extra_args = {} if extra_args is None else extra_args
        noise_sampler = EulerDpmppSampler.default_noise_sampler(x) if noise_sampler is None else noise_sampler
        s_in = x.new_ones([x.shape[0], 1])
        for i in trange(len(sigmas) - 1, disable=disable):
            t = sigmas[i] * (1 - 1 / self.total_step) + 1 / self.total_step
            t = torch.floor(t * self.total_step)  # the model expects an integer step when t is not normalised

            afa_bar_t = self.afas_cumprod[int(t) - 1]  # alpha-bar used for the noising at this step
            if self.normal_t:
                t = t / self.total_step

            t = t * s_in
            output = model(x, t, **extra_args)
            denoised = (x - torch.sqrt(1 - afa_bar_t) * output) / torch.sqrt(afa_bar_t)

            sigma_down, sigma_up = self.get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
            if callback is not None:
                callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
            d = self.to_d(x, sigmas[i], denoised)
            # Euler method
            dt = sigma_down - sigmas[i]
            x = x + d * dt
            if sigmas[i + 1] > 0:
                x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
            yield x

    @torch.no_grad()
    def sample_dpmpp_2m(self, model, x, sigmas, extra_args=None, callback=None, disable=None):
        """DPM-Solver++(2M)."""
        extra_args = {} if extra_args is None else extra_args
        s_in = x.new_ones([x.shape[0], 1])
        sigma_fn = lambda t: t.neg().exp()
        t_fn = lambda sigma: sigma.log().neg()
        old_denoised = None

        for i in trange(len(sigmas) - 1, disable=disable):
            t = sigmas[i] * (1 - 1 / self.total_step) + 1 / self.total_step
            t = torch.floor(t * self.total_step)  # the model expects an integer step when t is not normalised

            afa_bar_t = self.afas_cumprod[int(t) - 1]  # alpha-bar used for the noising at this step
            if self.normal_t:
                t = t / self.total_step

            t = t * s_in
            output = model(x, t, **extra_args)
            denoised = (x - torch.sqrt(1 - afa_bar_t) * output) / torch.sqrt(afa_bar_t)

            if callback is not None:
                callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
            t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
            h = t_next - t
            if old_denoised is None or sigmas[i + 1] == 0:
                x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised
            else:
                h_last = t - t_fn(sigmas[i - 1])
                r = h_last / h
                denoised_d = (1 + 1 / (2 * r)) * denoised - (1 / (2 * r)) * old_denoised
                x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_d
            old_denoised = denoised
            yield x

    def switch_sampler(self, sampler_name):
        if sampler_name == "euler a":
            self.sample_fun = self.sample_euler_ancestral
        elif sampler_name == "dpmpp 2m":
            self.sample_fun = self.sample_dpmpp_2m
        else:
            self.sample_fun = self.sample_euler_ancestral

    def sample_loop(self, model, vae_middle_c, batch_size, step, eta, shape=(32, 32)):
        x = torch.randn((batch_size, vae_middle_c, *shape)).to(self.device)
        sigmas = self.get_sigmas_karras(step, 1e-5, 0.999, device=self.device)

        looper = self.sample_fun(model, x, sigmas)
        for _ in trange(len(sigmas) - 1):
            x_t = next(looper)
            yield x_t

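# Illustrative sketch (not called by the app): get_sigmas_karras interpolates between
# sigma_max and sigma_min in rho-warped space and appends a trailing zero, so `n`
# steps produce n + 1 sigmas; the 1e-5 / 0.999 bounds mirror the ones sample_loop uses.
def _demo_karras_sigmas(n=10):
    sigmas = EulerDpmppSampler.get_sigmas_karras(n, 1e-5, 0.999, device="cpu")
    print(sigmas.shape, sigmas)
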
class PretrainVae:
    def __init__(self, device):
        from diffusers import AutoencoderKL, DiffusionPipeline
        self.vae = AutoencoderKL.from_pretrained("gsdf/Counterfeit-V2.5",  # segmind/small-sd
                                                 subfolder="vae",
                                                 cache_dir="./vae/pretrain_vae").to(device)
        self.vae.requires_grad_(False)
        self.middle_c = 4
        self.vae_scaleing = 0.18215

    def encoder(self, x):
        latents = self.vae.encode(x)
        latents = latents.latent_dist
        mean = latents.mean * self.vae_scaleing
        var = latents.var * self.vae_scaleing
        return mean, var

    def decoder(self, latents):
        latents = latents / self.vae_scaleing
        output = self.vae.decode(latents).sample
        return output

    # Free the encoder to save memory
    def res_encoder(self):
        del self.vae.encoder
        torch.cuda.empty_cache()

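# Illustrative sketch (not called by the app, and it downloads the Counterfeit-V2.5
# VAE weights on first use): an encode/decode round-trip through the pretrained VAE.
# The CPU device and 256x256 input size are assumptions matching the app defaults.
def _demo_vae_roundtrip():
    vae = PretrainVae(device="cpu")
    img = torch.rand(1, 3, 256, 256)     # dummy RGB image
    mean, var = vae.encoder(img)          # latent is (1, 4, 32, 32) after 8x downsampling
    recon = vae.decoder(mean)
    print(mean.shape, recon.shape)
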
# ================================================================

def merge_images(images: np.ndarray):
    """
    Merge a batch of images into one tiled image.
    :param images: image array, shaped (N, H, W, C)
    :return: the merged image array
    """
    n, h, w, c = images.shape
    nn = int(np.ceil(n ** 0.5))
    merged_image = np.zeros((h * nn, w * nn, 3), dtype=images.dtype)
    for i in range(n):
        row = i // nn
        col = i % nn
        merged_image[row * h:(row + 1) * h, col * w:(col + 1) * w, :] = images[i]

    merged_image = np.clip(merged_image, 0, 255)
    merged_image = np.array(merged_image, dtype=np.uint8)
    return merged_image

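# Illustrative sketch (not called by the app): merge_images tiles N images onto a
# ceil(sqrt(N)) x ceil(sqrt(N)) grid, so three 64x64 images land on a 128x128 canvas.
def _demo_merge_images():
    imgs = np.random.randint(0, 255, size=(3, 64, 64, 3), dtype=np.uint8)
    grid = merge_images(imgs)
    print(grid.shape)  # (128, 128, 3)
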
def get_models(device):
    def modelLoad(model, model_path, data_parallel=False):
        model.load_state_dict(torch.load(model_path), strict=True)

        if data_parallel:
            model = torch.nn.DataParallel(model)
        return model

    from net.UNet import UNet
    config = {
        # Model architecture
        "en_out_c": (256, 256, 256, 320, 320, 320, 576, 576, 576, 704, 704, 704),
        "en_down": (0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0),
        "en_skip": (0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1),
        "en_att_heads": (8, 8, 8, 0, 8, 8, 0, 8, 8, 0, 8, 8),
        "de_out_c": (704, 576, 576, 576, 320, 320, 320, 256, 256, 256, 256),
        "de_up": ("none", "subpix", "none", "none", "subpix", "none", "none", "subpix", "none", "none", "none"),
        "de_skip": (1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0),
        "de_att_heads": (8, 8, 0, 8, 8, 0, 8, 8, 0, 8, 8),  # no self-attention where there is a skip connection
        "t_out_c": 256,
        "vae_c": 4,
        "block_deep": 3,
        "use_pretrain_vae": True,

        "normal_t": True,

        "model_save_path": "./weight",
        "model_name": "unet",
        "model_tail": "ema",
    }
    print("Loading models...")
    unet = UNet(config["en_out_c"], config["en_down"], config["en_skip"], config["en_att_heads"],
                config["de_out_c"], config["de_up"], config["de_skip"], config["de_att_heads"],
                config["t_out_c"], config["vae_c"], config["block_deep"]).to(device)
    unet = modelLoad(unet, os.path.join(config["model_save_path"],
                                        f"{config['model_name']}_{config['model_tail']}.pth"))

    vae = PretrainVae(device)
    print("Models loaded")
    return unet, vae, config["normal_t"]


def init_webui(unet, vae, normal_t):
    # Callback run when the start button is clicked
    def process_image(input_image_value, noise_step, step_value, batch_size, sampler_name, img_size,
                      progress=gr.Progress()):
        progress(0, desc="Starting...")

        noise_step = float(noise_step)
        step_value = int(step_value)
        batch_size = int(batch_size)
        img_size = int(img_size) // 8
        img_size = (img_size, img_size)

        if sampler_name == "DDIM":
            sampler = DDIMSampler(device, normal_t)
        elif sampler_name == "euler a" or sampler_name == "dpmpp 2m":
            sampler = EulerDpmppSampler(device, normal_t)
            sampler.switch_sampler(sampler_name)
        else:
            raise ValueError(f"Unknown sampler_name: {sampler_name}")
        if input_image_value is None:
            looper = sampler.sample_loop(unet, vae.middle_c, batch_size, step_value, shape=img_size, eta=1.)
        else:
            input_image_value = Image.fromarray(input_image_value).resize(img_size, Image.LANCZOS)
            input_image_value = np.array(input_image_value, dtype=np.float32) / 255.
            input_image_value = np.transpose(input_image_value, (2, 0, 1))
            input_image_value = torch.Tensor([input_image_value]).to(device)
            input_img_latents = sampler.encode_img(vae, input_image_value)
            looper = sampler.sample_loop_img2img(input_img_latents,
                                                 int(noise_step * sampler.total_step),
                                                 unet,
                                                 vae.middle_c,
                                                 batch_size,
                                                 step_value,
                                                 eta=1.)
        for i in progress.tqdm(range(1, step_value + 1)):
            output = next(looper)

        output = sampler.decode_img(vae, output)
        output = np.clip(output, 0, 255)
        marge_img = merge_images(output)

        output = [marge_img] + list(output)

        return output

    with gr.Blocks(title="Image generation") as iface:
        with gr.Column():
            with gr.Row():
                with gr.Column():
                    # Input components
                    input_image = gr.Image(label="Input image")
                    # How much noise to add for img2img
                    noise_step = gr.Slider(minimum=0.05, maximum=1, value=0.6, label="Noise strength", step=0.01)
                with gr.Column():
                    # Sampler selection
                    sampler_name = gr.Dropdown(["DDIM"], label="sampler", value="DDIM")  # , "euler a", "dpmpp 2m"
                    # Sliders
                    step = gr.Slider(minimum=1, maximum=1000, value=400, label="Steps", step=1)
                    batch_size = gr.Slider(minimum=1, maximum=4, label="batch size", step=1)
                    img_size = gr.Slider(minimum=256, maximum=512, value=256, label="img size", step=64)
                    # Start button
                    start_button = gr.Button("Start")
            # Output gallery
            output_images = gr.Gallery(show_label=False, height=400, columns=5)

        start_button.click(process_image, [input_image, noise_step, step, batch_size, sampler_name, img_size],
                           [output_images])

    return iface


if __name__ == '__main__':
    device = "cuda"
    unet, vae, normal_t = get_models(device)


    def run_with_ui(unet, vae, normal_t):
        # Build the interface
        iface = init_webui(unet, vae, normal_t)

        # Launch the interface
        iface.queue().launch()


    run_with_ui(unet, vae, normal_t)
net/UNet.py ADDED
@@ -0,0 +1,520 @@
"""
att_uncontrol9_adam and all earlier runs use this UNet definition.

"""

import numpy as np
import torch
import torch.nn as nn
import math


class SubPixelConv(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, scale_factor=2):
        super(SubPixelConv, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels * scale_factor ** 2, kernel_size, stride,
                              padding=kernel_size // 2)
        self.pixel_shuffle = nn.PixelShuffle(scale_factor)

    def forward(self, x):
        x = self.conv(x)
        x = self.pixel_shuffle(x)
        return x


class Swish(nn.Module):
    def __init__(self):
        super(Swish, self).__init__()

    def forward(self, x):
        # swish
        return x * torch.sigmoid(x)


def zero_module(module):
    """
    Zero out the parameters of a module and return it.
    """
    for p in module.parameters():
        p.detach().zero_()
    return module


class AttentionBlock(nn.Module):
    """
    An attention block that allows spatial positions to attend to each other.

    Originally ported from here, but adapted to the N-d case.
    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
    """

    def __init__(self, channels, num_heads=-1, use_checkpoint=False):
        super().__init__()
        self.channels = channels
        self.num_heads = num_heads if num_heads != -1 else min(channels // 32, 8)
        self.use_checkpoint = use_checkpoint

        self.norm = nn.GroupNorm(16, channels, eps=1e-6)
        self.qkv = nn.Conv1d(channels, channels * 3, 1)
        self.attention = QKVAttention()
        self.proj_out = zero_module(nn.Conv1d(channels, channels, 1))

    def forward(self, x):
        b, c, *spatial = x.shape
        x = x.reshape(b, c, -1)
        qkv = self.qkv(self.norm(x))
        qkv = qkv.reshape(b * self.num_heads, -1, qkv.shape[2])
        h = self.attention(qkv)
        h = h.reshape(b, -1, h.shape[-1])
        h = self.proj_out(h)
        return (x + h).reshape(b, c, *spatial)


class QKVAttention(nn.Module):
    """
    A module which performs QKV attention.
    """

    def forward(self, qkv):
        """
        Apply QKV attention.

        :param qkv: an [N x (C * 3) x T] tensor of Qs, Ks, and Vs.
        :return: an [N x C x T] tensor after attention.
        """
        ch = qkv.shape[1] // 3
        q, k, v = torch.split(qkv, ch, dim=1)
        scale = 1 / math.sqrt(math.sqrt(ch))
        weight = torch.einsum(
            "bct,bcs->bts", q * scale, k * scale
        )  # More stable with f16 than dividing afterwards
        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
        return torch.einsum("bts,bcs->bct", weight, v)

    @staticmethod
    def count_flops(model, _x, y):
        """
        A counter for the `thop` package to count the operations in an
        attention operation.

        Meant to be used like:

            macs, params = thop.profile(
                model,
                inputs=(inputs, timestamps),
                custom_ops={QKVAttention: QKVAttention.count_flops},
            )

        """
        b, c, *spatial = y[0].shape
        num_spatial = int(np.prod(spatial))
        # We perform two matmuls with the same number of ops.
        # The first computes the weight matrix, the second computes
        # the combination of the value vectors.
        matmul_ops = 2 * b * (num_spatial ** 2) * c
        model.total_ops += torch.DoubleTensor([matmul_ops])

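# Illustrative sketch (not called by the model): AttentionBlock flattens the spatial
# dims, runs multi-head QKV self-attention over them, and reshapes back, so the output
# shape always matches the input. The 64-channel input below is an arbitrary assumption.
def _demo_attention_shape():
    block = AttentionBlock(channels=64, num_heads=8)
    x = torch.randn(2, 64, 16, 16)
    print(block(x).shape)  # torch.Size([2, 64, 16, 16])
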
# ====================================================================

class TEncoder(nn.Module):
    def __init__(self, out_c=256, scale=30.):
        super(TEncoder, self).__init__()
        # Random Fourier projection (kept for compatibility; unused by the sinusoidal embedding)
        self.out_c = out_c
        self.W = nn.Parameter(torch.randn(out_c // 2) * scale, requires_grad=False)
        self.linear = nn.Sequential(nn.Linear(out_c, out_c),
                                    Swish(),
                                    nn.Linear(out_c, out_c),
                                    )

    def timestep_embedding(self, timesteps, max_period=10000):
        """
        Create sinusoidal timestep embeddings.
        :param timesteps: a 1-D Tensor of N indices, one per batch element.
                          These may be fractional.
        :param max_period: controls the minimum frequency of the embeddings.
        :return: an [N x dim] Tensor of positional embeddings.
        """
        half = self.out_c // 2
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
        ).to(device=timesteps.device)
        args = timesteps[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if self.out_c % 2:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        return embedding

    def forward(self, t):
        t_proj = self.timestep_embedding(t)[:, 0, :]
        encoded_t = self.linear(t_proj)
        return encoded_t

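# Illustrative sketch (not called by the model): TEncoder turns a (N, 1) batch of
# timesteps into sinusoidal features and then into an (N, out_c) embedding; the
# normalised t values below mirror what app.py passes when normal_t is True.
def _demo_timestep_embedding():
    enc = TEncoder(out_c=256)
    t = torch.tensor([[0.1], [0.5], [1.0]])
    print(enc.timestep_embedding(t).shape)  # torch.Size([3, 1, 256])
    print(enc(t).shape)                     # torch.Size([3, 256])
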
class EncoderBlock(nn.Module):
    def __init__(self, in_c, out_c, kernel_size, stride, t_in_c, att_num_head=-1, block_deep=4):
        super(EncoderBlock, self).__init__()
        self.in_c = in_c
        self.out_c = out_c
        self.stride = stride
        self.model_list_len = block_deep  # number of convolutions in one block

        padding = kernel_size // 2
        self.model_list = nn.ModuleList()
        self.model_list.append(nn.Sequential(
            nn.Conv2d(in_c, out_c, kernel_size=kernel_size, stride=stride, padding=padding),
            nn.GroupNorm(16, out_c, eps=1e-6),
            Swish()))
        if att_num_head != 0:  # stride == 1
            self.att_block = AttentionBlock(out_c, num_heads=att_num_head)
        else:
            self.att_block = nn.Identity()
        for _ in range(self.model_list_len - 2):  # -2 accounts for the first and last layers
            self.model_list.append(
                nn.Sequential(
                    nn.Conv2d(out_c, out_c, kernel_size=kernel_size, stride=1,
                              padding=padding),
                    nn.GroupNorm(16, out_c, eps=1e-6),
                    Swish(),
                ))
        self.model_list.append(
            nn.Sequential(
                nn.Conv2d(out_c, out_c, kernel_size=kernel_size, stride=1,
                          padding=padding),
                nn.GroupNorm(16, out_c, eps=1e-6),
            ))

        # Timestep-embedding projections, one per conv layer except the last
        self.encode_t = nn.ModuleList(
            [nn.Linear(t_in_c, out_c) for _ in range(len(self.model_list) - 1)])

        if self.in_c != self.out_c or self.stride != 1:
            self.conv_skip = nn.Conv2d(in_c, out_c, kernel_size=1, stride=stride, padding=0)
        else:
            self.conv_skip = nn.Identity()
        self.act_skip = Swish()

    def forward(self, x, t):
        skip = self.conv_skip(x)

        for i, layer in enumerate(self.model_list):
            x = layer(x)
            if i == 0:
                x = self.att_block(x)
            if i < self.model_list_len - 1:
                t_ = self.encode_t[i](t)
                t_ = t_[:, :, None, None]
                x = x + t_

        return self.act_skip(x + skip)


class DecoderBlock(nn.Module):
    def __init__(self, in_c, out_c, kernel_size, upsample="none", t_in_c=256, att_num_head=-1, block_deep=4):
        super(DecoderBlock, self).__init__()
        self.in_c = in_c
        self.out_c = out_c
        self.model_list_len = block_deep  # number of convolutions in one block

        self.model_list = nn.ModuleList()

        if upsample == "subpix":
            self.model_list.append(nn.Sequential(
                SubPixelConv(in_c, out_c, kernel_size=3),
                nn.GroupNorm(16, out_c, eps=1e-6),
                Swish()
            ))

            self.upsample = SubPixelConv(in_c, in_c, kernel_size=3)
        elif upsample == "convt":
            self.model_list.append(nn.Sequential(
                nn.ConvTranspose2d(in_c, out_c, kernel_size=4, stride=2, padding=1),
                nn.GroupNorm(16, out_c, eps=1e-6),
                Swish()
            ))

            self.upsample = nn.ConvTranspose2d(in_c, in_c, kernel_size=4, stride=2, padding=1)
        else:
            self.model_list.append(nn.Sequential(
                nn.Conv2d(in_c, out_c, kernel_size=kernel_size, stride=1,
                          padding=kernel_size // 2),
                nn.GroupNorm(16, out_c, eps=1e-6),
                Swish()
            ))
            self.upsample = nn.Identity()

        if att_num_head != 0:  # upsample != "none"
            self.att_block = AttentionBlock(out_c, num_heads=att_num_head)
        else:
            self.att_block = nn.Identity()

        for _ in range(self.model_list_len - 2):
            self.model_list.append(nn.Sequential(nn.Conv2d(out_c, out_c, kernel_size=kernel_size, stride=1,
                                                           padding=kernel_size // 2),
                                                 nn.GroupNorm(16, out_c, eps=1e-6),
                                                 Swish()))

        self.model_list.append(nn.Sequential(nn.Conv2d(out_c, out_c, kernel_size=kernel_size, stride=1,
                                                       padding=kernel_size // 2),
                                             nn.GroupNorm(16, out_c, eps=1e-6)))

        # Timestep-embedding projections, one per conv layer except the last
        self.encode_t = nn.ModuleList([nn.Linear(t_in_c, out_c) for _ in range(len(self.model_list) - 1)])

        self.conv_skip = nn.Conv2d(in_c, out_c, kernel_size=1, stride=1, padding=0)
        self.act_skip = Swish()

    def forward(self, x, t):
        skip = self.upsample(x)
        skip = self.conv_skip(skip)

        for i, layer in enumerate(self.model_list):
            x = layer(x)
            if i == 0:
                x = self.att_block(x)
            if i < self.model_list_len - 1:
                t_ = self.encode_t[i](t)
                t_ = t_[:, :, None, None]
                x = x + t_

        return self.act_skip(x + skip)


class Encoder(nn.Module):
    def __init__(self,
                 model_in_c=8,
                 out_cs=(64, 64, 128, 128, 256, 256, 512, 512),
                 down_sample=(0, 0, 1, 0, 1, 0, 1, 0),
                 skip_out=(0, 1, 0, 1, 0, 1, 0, 1),
                 att_num_heads=(-1, -1, -1, -1, -1, -1, -1, -1),
                 t_in_c=256,
                 block_deep=4):
        """
        :param out_cs: output channels of each block
        :param down_sample: whether each block downsamples
        :param skip_out: which blocks feed the U-Net skip connections
        """
        super(Encoder, self).__init__()

        self.skip_out = skip_out

        self.model_list = nn.ModuleList()
        for i, (out_c, down, att_num_head) in enumerate(zip(out_cs, down_sample, att_num_heads)):
            in_c = model_in_c if i == 0 else out_cs[i - 1]
            self.model_list.append(
                EncoderBlock(in_c, out_cs[i], kernel_size=3, stride=down + 1, t_in_c=t_in_c,
                             att_num_head=att_num_head, block_deep=block_deep))

    def forward(self, x, t):
        res_x = []
        for i, layer in enumerate(self.model_list):
            x = layer(x, t)
            if self.skip_out[i] == 1:
                res_x.append(x)
        return res_x


class Decoder(nn.Module):
    def __init__(self,
                 in_c,
                 model_out_c=8,
                 out_cs=(512, 256, 256, 128, 128, 64, 64, 32),
                 up_sample=("none", "convt", "none", "subpix", "none", "subpix", "none", "none"),
                 skip_out=(1, 0, 1, 0, 1, 0, 1, 0),
                 att_num_heads=(-1, -1, -1, -1, -1, -1, -1, -1),
                 t_in_c=256,
                 block_deep=4):
        """
        :param out_cs: output channels of each block
        :param up_sample: upsampling method; "none" means no upsampling
        :param skip_out: which blocks consume a U-Net skip connection
        """
        super(Decoder, self).__init__()

        self.skip_out = skip_out
        self.model_list = nn.ModuleList()
        for i, (out_c, up, att_num_head) in enumerate(zip(out_cs, up_sample, att_num_heads)):
            if self.skip_out[i] == 1 and i > 0:
                in_c *= 2
            self.model_list.append(
                DecoderBlock(in_c, out_cs[i], kernel_size=3, upsample=up, t_in_c=t_in_c,
                             att_num_head=att_num_head, block_deep=block_deep))
            in_c = out_cs[i]

        self.Conv1 = nn.Conv2d(out_cs[-1], model_out_c, kernel_size=1, stride=1, padding=0)

    def forward(self, x, t):
        x_list = x
        x = None
        for i, layer in enumerate(self.model_list):
            if self.skip_out[i] == 1:
                if i == 0:
                    x = x_list.pop()
                else:
                    x = torch.cat([x, x_list.pop()], dim=1)
            x = layer(x, t)

        x = self.Conv1(x)
        return x


class UNet(nn.Module):
    def __init__(self,
                 en_out_c,
                 en_down,
                 en_skip,
                 en_att_heads,
                 de_out_c,
                 de_up,
                 de_skip,
                 de_att_heads,
                 t_out_c,
                 vae_c=8,
                 block_deep=4):
        """
        :param en_out_c: encoder parameters
        :param en_down:
        :param en_skip:
        :param de_out_c: decoder parameters
        :param de_up:
        :param de_skip:
        """
        super(UNet, self).__init__()

        self.encoder = Encoder(model_in_c=vae_c,
                               out_cs=en_out_c,
                               down_sample=en_down,
                               skip_out=en_skip,
                               att_num_heads=en_att_heads,
                               t_in_c=t_out_c,
                               block_deep=block_deep)
        self.decoder = Decoder(in_c=en_out_c[-1],
                               model_out_c=vae_c,
                               out_cs=de_out_c,
                               up_sample=de_up,
                               skip_out=de_skip,
                               att_num_heads=de_att_heads,
                               t_in_c=t_out_c,
                               block_deep=block_deep)
        self.t_encoder = TEncoder(t_out_c)

    def forward(self, x, t):
        t = self.t_encoder(t)
        encoder_out = self.encoder(x, t)
        decoder_out = self.decoder(encoder_out, t)
        return decoder_out


if __name__ == '__main__':
    import cv2, os


    def modelSave(model, save_path, save_name):
        if not os.path.exists(save_path):
            os.mkdir(save_path)
        torch.save(model.state_dict(), os.path.join(save_path, save_name))


    def merge_images(images: np.ndarray):
        """
        Merge a batch of images into one tiled image.
        :param images: image array, shaped (N, H, W, C)
        :return: the merged image array
        """
        n, h, w, c = images.shape
        nn = int(np.ceil(n ** 0.5))
        merged_image = np.zeros((h * nn, w * nn, 3), dtype=images.dtype)
        for i in range(n):
            row = i // nn
            col = i % nn
            merged_image[row * h:(row + 1) * h, col * w:(col + 1) * w, :] = images[i]

        merged_image = np.clip(merged_image, 0, 255)
        merged_image = np.array(merged_image, dtype=np.uint8)
        return merged_image


    # 320, 448, 576, 832
    config = {  # model architecture
        "en_out_c": (256, 256, 256, 320, 320, 320, 576, 576, 576, 704, 704, 704),
        "en_down": (0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0),
        "en_skip": (0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1),
        "en_att_heads": (8, 8, 8, 0, 8, 8, 0, 8, 8, 0, 8, 8),
        "de_out_c": (704, 576, 576, 576, 320, 320, 320, 256, 256, 256, 256),
        "de_up": ("none", "subpix", "none", "none", "subpix", "none", "none", "subpix", "none", "none", "none"),
        "de_skip": (1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0),
        "de_att_heads": (8, 8, 0, 8, 8, 0, 8, 8, 0, 8, 8),  # no self-attention where there is a skip connection
        "t_out_c": 256,
        "vae_c": 4,
        "block_deep": 3,
    }
    device = "cuda"
    total_step = 1000

    unet = UNet(config["en_out_c"], config["en_down"], config["en_skip"], config["en_att_heads"],
                config["de_out_c"], config["de_up"], config["de_skip"], config["de_att_heads"],
                config["t_out_c"], config["vae_c"], config["block_deep"]).to(device)

    print("total params", sum(i.numel() for i in unet.parameters()) / 10000, "(unit: 10k)")
    print("encoder", sum(i.numel() for i in unet.encoder.parameters()) / 10000, "(unit: 10k)")
    print("decoder", sum(i.numel() for i in unet.decoder.parameters()) / 10000, "(unit: 10k)")
    print("t", sum(i.numel() for i in unet.t_encoder.parameters()) / 10000, "(unit: 10k)")

    batch_size = 2
    x = np.random.random((batch_size, config["vae_c"], 32, 32))
    t = np.random.uniform(1, total_step + 0.9999, size=(batch_size, 1))
    t = np.array(t, dtype=np.int16)
    t = t / total_step

    with torch.no_grad():
        x = torch.Tensor(x).to(device)
        t = torch.Tensor(t).to(device)
        y = unet(x, t)
        print(y.shape)

    z = y[0].cpu().numpy()
    z = np.clip(np.asarray((z + 1) * 127.5), 0, 255)
    z = np.asarray(z, dtype=np.uint8)

    z = [np.tile(z[ii, :, :, np.newaxis], (1, 1, 3)) for ii in range(z.shape[0])]
    noise = merge_images(np.array(z))

    noise = cv2.resize(noise, None, fx=2, fy=2)
    cv2.imshow("noise", noise)
    cv2.waitKey(0)

    # Export to ONNX
    torch.onnx.export(
        unet,
        (x, t),
        'unet.onnx',
        export_params=True,
        opset_version=12,
    )
    import onnx

    # Add shape information to the exported graph
    model_file = 'unet.onnx'
    onnx_model = onnx.load(model_file)
    onnx.save(onnx.shape_inference.infer_shapes(onnx_model), model_file)
requirements.txt ADDED
@@ -0,0 +1,186 @@
absl-py==1.3.0
addict==2.4.0
aiofiles==23.1.0
aiohttp==3.8.3
aiosignal==1.3.1
aliyun-python-sdk-core==2.13.36
aliyun-python-sdk-kms==2.16.0
altair==4.2.0
anyio==3.6.2
appdirs==1.4.4
asttokens==2.3.0
async-timeout==4.0.2
attrs==22.1.0
audioread==3.0.0
backcall==0.2.0
certifi==2022.12.7
cffi==1.15.1
charset-normalizer==2.1.1
chumpy==0.70
click==8.1.3
clip==1.0
colorama==0.4.6
commonmark==0.9.1
contourpy==1.0.6
cpm-kernels==1.0.11
crcmod==1.7
cryptography==39.0.2
cycler==0.11.0
Cython==0.29.32
datasets==2.8.0
decorator==5.1.1
decord==0.6.0
diffusers==0.20.1
dill==0.3.6
docker-pycreds==0.4.0
einops==0.6.0
entrypoints==0.4
exceptiongroup==1.1.3
executing==1.2.0
fastapi==0.88.0
ffmpy==0.3.0
filelock==3.8.2
Flask==2.0.2
Flask-Cors==3.0.10
fonttools==4.38.0
frozenlist==1.3.3
fsspec==2022.11.0
ftfy==6.1.1
gast==0.5.3
gitdb==4.0.10
GitPython==3.1.32
gradio==3.39.0
gradio_client==0.3.0
h11==0.14.0
httpcore==0.16.2
httpx==0.23.1
huggingface-hub==0.16.4
icetk==0.0.4
idna==3.4
importlib-metadata==5.2.0
ipython==8.15.0
itsdangerous==2.1.2
jedi==0.19.0
Jinja2==3.1.2
jmespath==0.10.0
joblib==1.2.0
json-tricks==3.16.1
jsonplus==0.8.0
jsonschema==4.17.3
kiwisolver==1.4.4
lazy_loader==0.1
librosa==0.10.0
linkify-it-py==1.0.3
lion-pytorch==0.1.2
llvmlite==0.39.1
loguru==0.6.0
Markdown==3.4.1
markdown-it-py==2.1.0
MarkupSafe==2.1.1
matplotlib==3.6.2
matplotlib-inline==0.1.6
mdit-py-plugins==0.3.3
mdurl==0.1.2
mediapipe==0.8.11
mmcv-full==1.7.0
mmdet==2.26.0
model-index==0.1.11
modelscope==1.3.2
mpmath==1.2.1
msgpack==1.0.4
multidict==6.0.3
multiprocess==0.70.14
munkres==1.1.4
networkx==3.0
numba==0.56.4
numpy==1.23.4
onnx==1.14.1
opencv-contrib-python==4.5.5.64
opencv-python==4.5.5.64
openmim==0.3.3
ordered-set==4.1.0
orjson==3.8.3
oss2==2.16.0
packaging==21.3
pandas==1.5.2
parso==0.8.3
pathtools==0.1.2
pickleshare==0.7.5
Pillow==9.2.0
pip==23.1.2
platformdirs==3.1.0
plotly==5.11.0
pooch==1.7.0
prodigyopt==1.0
prompt-toolkit==3.0.39
protobuf==4.24.2
psutil==5.9.5
pure-eval==0.2.2
pyarrow==11.0.0
pycocotools==2.0.6
pycparser==2.21
pycryptodome==3.16.0
pydantic==1.10.2
pydub==0.25.1
Pygments==2.13.0
pyparsing==3.0.9
pyrsistent==0.19.2
python-dateutil==2.8.2
python-multipart==0.0.5
pytorch-fid==0.3.0
pytz==2022.6
PyYAML==6.0
regex==2022.10.31
requests==2.28.1
responses==0.18.0
rfc3986==1.5.0
rich==12.6.0
safetensors==0.3.3
scikit-learn==1.2.1
scipy==1.9.3
semantic-version==2.10.0
sentencepiece==0.1.97
sentry-sdk==1.28.0
setproctitle==1.3.2
setuptools==65.5.0
simplejson==3.18.3
six==1.16.0
smmap==5.0.0
sniffio==1.3.0
sortedcontainers==2.4.0
soundfile==0.12.1
soxr==0.3.4
stack-data==0.6.2
starlette==0.22.0
sympy==1.11.1
tabulate==0.9.0
tenacity==8.1.0
terminaltables==3.1.10
threadpoolctl==3.1.0
timm==0.4.9
tokenizers==0.13.2
toolz==0.12.0
torch==2.0.0+cu117
torchaudio==2.0.1+cu117
torchinfo==1.7.1
torchvision==0.15.1+cu117
tqdm==4.64.1
traitlets==5.9.0
transformers==4.26.1
typing_extensions==4.4.0
uc-micro-py==1.0.1
unicodedata2==15.0.0
urllib3==1.26.12
uvicorn==0.20.0
wandb==0.15.5
wcwidth==0.2.5
websockets==10.4
Werkzeug==2.2.2
wheel==0.37.1
win32-setctime==1.1.0
wincertstore==0.2
xtcocotools==1.12
xxhash==3.2.0
yapf==0.32.0
yarl==1.8.2
zipp==3.11.0
vae/pretrain_vae/models--gsdf--Counterfeit-V2.5/refs/main ADDED
@@ -0,0 +1 @@
93c5412baf37cbfa23a3278f7b33b0328db581fb
vae/pretrain_vae/models--gsdf--Counterfeit-V2.5/snapshots/93c5412baf37cbfa23a3278f7b33b0328db581fb/vae/config.json ADDED
@@ -0,0 +1,29 @@
{
  "_class_name": "AutoencoderKL",
  "_diffusers_version": "0.10.2",
  "act_fn": "silu",
  "block_out_channels": [
    128,
    256,
    512,
    512
  ],
  "down_block_types": [
    "DownEncoderBlock2D",
    "DownEncoderBlock2D",
    "DownEncoderBlock2D",
    "DownEncoderBlock2D"
  ],
  "in_channels": 3,
  "latent_channels": 4,
  "layers_per_block": 2,
  "norm_num_groups": 32,
  "out_channels": 3,
  "sample_size": 256,
  "up_block_types": [
    "UpDecoderBlock2D",
    "UpDecoderBlock2D",
    "UpDecoderBlock2D",
    "UpDecoderBlock2D"
  ]
}
vae/pretrain_vae/models--gsdf--Counterfeit-V2.5/snapshots/93c5412baf37cbfa23a3278f7b33b0328db581fb/vae/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:af03509f25bf282de98626830ef4fa607e596d0d0fbda8f1d6f5ccaa1d334640
size 334643276
weight/unet_ema.pth ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:598d60a65f5463df4c3c33879c887c5029b41a60b52c4d1481f99e47548b8ff2
size 857352782