import jittor as jt
from jittor import nn
import argparse
import os
import numpy as np
from jittor.dataset.mnist import MNIST
import jittor.transform as transform
import cv2
import time
from jittor.dataset.dataset import ImageFolder

jt.flags.use_cuda = 1

save_img_path = './images_celebA'
save_model_path = './save_model_celebA'
os.makedirs(save_img_path, exist_ok=True)
os.makedirs(save_model_path, exist_ok=True)

parser = argparse.ArgumentParser()
parser.add_argument('--n_epochs', type=int, default=200, help='number of training epochs')
parser.add_argument('--batch_size', type=int, default=128, help='size of each batch')
parser.add_argument('--lr', type=float, default=0.0002, help='learning rate')
parser.add_argument('--b1', type=float, default=0.5, help='Adam: decay of first-order momentum of gradient')
parser.add_argument('--b2', type=float, default=0.999, help='Adam: decay of second-order momentum of gradient')
parser.add_argument('--n_cpu', type=int, default=8, help='number of cpu threads to use during batch generation')
parser.add_argument('--latent_dim', type=int, default=100, help='dimensionality of the latent space')
parser.add_argument('--img_size', type=int, default=28, help='size of each image dimension')
parser.add_argument('--celebA_channels', type=int, default=3, help='number of celebA image channels')
parser.add_argument('--mnist_channels', type=int, default=1, help='number of MNIST image channels')
parser.add_argument('--n_critic', type=int, default=5, help='number of discriminator training steps per generator iteration')
parser.add_argument('--clip_value', type=float, default=0.01, help='lower and upper clip value for discriminator weights')
parser.add_argument('--sample_interval', type=int, default=400, help='interval between image samples')
parser.add_argument('--task', type=str, default='celebA', help='type of training dataset')
parser.add_argument('--train_dir', type=str, default='D:\\Image_Generation_Learn\\Dataset\\CelebA_train', help='path to the training dataset')
opt = parser.parse_args()
print(opt)

img_shape = (opt.celebA_channels, opt.img_size, opt.img_size)

# Training set loader
def DataLoader(dataclass, img_size, batch_size, train_dir):
    if dataclass == 'MNIST':
        Transform = transform.Compose([
            transform.Resize(size=img_size),
            transform.Gray(),
            transform.ImageNormalize(mean=[0.5], std=[0.5])])
        train_loader = MNIST(data_root=train_dir, train=True, transform=Transform)\
            .set_attrs(batch_size=batch_size, shuffle=True)
    elif dataclass == 'celebA':
        Transform = transform.Compose([
            transform.Resize(size=img_size),
            transform.ImageNormalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])])
        train_loader = ImageFolder(train_dir)\
            .set_attrs(transform=Transform, batch_size=batch_size, shuffle=True)
    else:
        print("No loader for the %s dataset; please choose MNIST or celebA!" % dataclass)
        dataclass = input("Please enter MNIST or celebA: ")
        return DataLoader(dataclass, img_size, batch_size, train_dir)
    return train_loader

dataloader = DataLoader(opt.task, opt.img_size, opt.batch_size, opt.train_dir)
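# A minimal sanity check, not part of the original training flow: fetch one
# batch and confirm the transforms yield (batch_size, channels, img_size,
# img_size) arrays. `peek_batch` is a hypothetical helper added for
# illustration; uncomment the print below to use it.
def peek_batch(loader):
    for imgs, _ in loader:
        return tuple(imgs.shape)
# print('first batch shape:', peek_batch(dataloader))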
# Save a grid of nrow x nrow images to disk (assumes img holds nrow*nrow images)
def save_image(img, path, nrow=10):
    N, C, W, H = img.shape
    img2 = img.reshape([-1, W * nrow * nrow, H])
    img = img2[:, :W * nrow, :]
    for i in range(1, nrow):
        img = np.concatenate([img, img2[:, W * nrow * i:W * nrow * (i + 1), :]], axis=2)
    min_ = img.min()
    max_ = img.max()
    img = (img - min_) / (max_ - min_) * 255
    img = img.transpose((1, 2, 0))
    cv2.imwrite(path, img)

# Generator
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()

        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat, 0.8))
            layers.append(nn.LeakyReLU(0.2))
            return layers

        self.model = nn.Sequential(
            *block(opt.latent_dim, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(np.prod(img_shape))),
            nn.Tanh())

    def execute(self, z):
        img = self.model(z)
        img = img.view((img.shape[0], *img_shape))
        return img

# Discriminator (critic): no sigmoid at the end, it outputs an unbounded score
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(int(np.prod(img_shape)), 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1))

    def execute(self, img):
        img_flat = img.reshape((img.shape[0], -1))
        validity = self.model(img_flat)
        return validity

lambda_gp = 10

# Initialize generator and discriminator
generator = Generator()
discriminator = Discriminator()

# Optimizers
optimizer_G = jt.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = jt.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

# Loss function: gradient penalty term of WGAN-GP
def compute_gradient_penalty(D, real_samples, fake_samples):
    # Random interpolation coefficient, one per sample
    alpha = jt.array(np.random.random((real_samples.shape[0], 1, 1, 1)).astype('float32'))
    interpolates = alpha * real_samples + (1 - alpha) * fake_samples
    d_interpolates = D(interpolates)
    gradients = jt.grad(d_interpolates, interpolates)
    gradients = gradients.reshape((gradients.shape[0], -1))
    # Penalize deviation of the critic's gradient norm from 1
    gp = ((jt.sqrt(gradients.sqr().sum(1)) - 1).sqr()).mean()
    return gp
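# For reference: `--clip_value` is parsed above but never used, because this
# script constrains the critic with the gradient penalty rather than the
# original WGAN weight clipping. A rough sketch of clipping is below, assuming
# jt.clamp and Var.assign behave as in current Jittor releases; the helper is
# hypothetical and intentionally never called.
def clip_discriminator_weights(D, clip_value):
    for p in D.parameters():
        # Force every critic weight into [-clip_value, clip_value]
        p.assign(jt.clamp(p, -clip_value, clip_value))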
batches_done = 0
warmup_times = -1
run_times = 3000
total_time = 0.
cnt = 0

# ----------
#  Training
# ----------
for epoch in range(opt.n_epochs):  # 200
    for i, (imgs, _) in enumerate(dataloader):
        real_imgs = jt.array(imgs).float32()

        # ---------------------
        #  Train discriminator
        # ---------------------
        z = jt.array(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim)).astype('float32'))
        fake_imgs = generator(z)
        real_validity = discriminator(real_imgs)
        fake_validity = discriminator(fake_imgs)
        gradient_penalty = compute_gradient_penalty(discriminator, real_imgs, fake_imgs)
        d_loss = -real_validity.mean() + fake_validity.mean() + lambda_gp * gradient_penalty
        d_loss.sync()
        optimizer_D.step(d_loss)

        # --------------------------------------------------------------
        #  Train generator (once every opt.n_critic discriminator steps)
        # --------------------------------------------------------------
        if i % opt.n_critic == 0:
            fake_img = generator(z)
            fake_validityg = discriminator(fake_img)
            g_loss = -fake_validityg.mean()
            g_loss.sync()
            optimizer_G.step(g_loss)
            if warmup_times == -1:
                print('[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]'
                      % (epoch, opt.n_epochs, i, len(dataloader), d_loss.data, g_loss.data))
                # if batches_done % opt.sample_interval == 0:
                if i == 1583:  # last batch index depends on opt.batch_size; saves once per epoch
                    save_image(fake_imgs.data[:25], '%s/%d.png' % (save_img_path, batches_done), nrow=5)
                batches_done += opt.n_critic

        if warmup_times != -1:
            jt.sync_all()
            cnt += 1
            print(cnt)
            if cnt == warmup_times:
                jt.sync_all(True)
                sta = time.time()
            if cnt > warmup_times + run_times:
                jt.sync_all(True)
                total_time = time.time() - sta
                print(f"run {run_times} iters cost {total_time} seconds, and avg {total_time / run_times} one iter.")
                exit(0)

    if epoch % 10 == 0:  # 0-199
        generator.save("%s/generator_%s.pkl" % (save_model_path, opt.task))
        discriminator.save("%s/discriminator_%s.pkl" % (save_model_path, opt.task))
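# Illustrative sketch (not executed during training): reload a saved generator
# checkpoint and write a grid of samples, mirroring the save_image call above.
# `sample_from_checkpoint` is a hypothetical helper; it assumes jittor's
# Module.load accepts the .pkl files written by Module.save.
def sample_from_checkpoint(ckpt_path, out_path, n_samples=25):
    g = Generator()
    g.load(ckpt_path)
    g.eval()
    z = jt.array(np.random.normal(0, 1, (n_samples, opt.latent_dim)).astype('float32'))
    save_image(g(z).data, out_path, nrow=5)
# Example:
# sample_from_checkpoint('%s/generator_%s.pkl' % (save_model_path, opt.task), 'sample.png')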