GANs Implementation Based on MNIST [PyTorch]

太过爱你忘了你带给我的痛 2022-04-17 01:38

## Overview ##

This code is adapted from two earlier posts of mine (both already contain very detailed explanations, so have a look):

* [[Intro to GANs] pytorch-GANs task transfer – single target (digit generation)][GANs_pytorch-GANs_-]
* [[Intro to GANs] Detailed walkthrough of a PyTorch GANs implementation (70+ lines of code)][Gans_Pytorch_Gans_70]

It also incorporates the DCGANs implementation I wrote earlier:

* [DCGANs (Deep Convolutional GANs) paper notes and PyTorch implementation][DCGANs_pytorch]

Selecting a specific digit from MNIST follows this article:

* [Selecting a specific digit from the MNIST training set][MNIST]

Those earlier posts explain everything in detail, and this is only a small improvement on top of them, so I won't give a full walkthrough here; the comments are still kept in the code.

## How the Generated Images Evolve ##

![Generated digits over the course of training][20181111101752788.gif]

## Code ##

```python
import torch
import torch.nn as nn
import torchvision
import torch.utils.data as Data
import matplotlib.pyplot as plt
import os
import shutil
import imageio

PNGFILE = './png/'
if not os.path.exists(PNGFILE):
    os.mkdir(PNGFILE)
else:
    shutil.rmtree(PNGFILE)
    os.mkdir(PNGFILE)

# Hyper Parameters
BATCH_SIZE = 64
LR_G = 0.0001  # learning rate for generator
LR_D = 0.0001  # learning rate for discriminator
N_IDEAS = 100  # think of this as number of ideas for generating an art work (Generator)
target_num = 0  # target digit
EPOCH = 10  # how many passes over the whole training set
DOWNLOAD_MNIST = False  # automatically skipped if the data set is already downloaded
ART_COMPONENTS = 28 * 28  # MNIST handwritten digits are 28x28


class myMNIST(torchvision.datasets.MNIST):
    def __init__(self, root, train=True, transform=None, target_transform=None,
                 download=False, targetNum=None):
        super(myMNIST, self).__init__(root, train=train, transform=transform,
                                      target_transform=target_transform, download=download)
        if targetNum is not None:
            # keep only the target digit, truncated to a multiple of BATCH_SIZE
            # (older torchvision called these attributes train_data / train_labels;
            #  current versions expose them as data / targets)
            self.data = self.data[self.targets == targetNum]
            self.data = self.data[:int(self.__len__() / BATCH_SIZE) * BATCH_SIZE]
            self.targets = self.targets[self.targets == targetNum][
                           :int(self.__len__() / BATCH_SIZE) * BATCH_SIZE]

    def __len__(self):
        if self.train:
            return self.data.shape[0]
        else:
            return 10000


train_data = myMNIST(
    root='./mnist/',  # where to save / load the data set
    train=True,  # this is training data
    transform=torchvision.transforms.ToTensor(),  # converts a PIL.Image or numpy.ndarray to a
    # torch.FloatTensor of shape (C x H x W), normalized into [0.0, 1.0] for training
    download=DOWNLOAD_MNIST,  # download if not present; skipped otherwise
    targetNum=target_num
)
print(len(train_data))
# print(train_data.shape)

# feed BATCH_SIZE images per step; each image is 28*28
train_loader = Data.DataLoader(
    dataset=train_data,
    batch_size=BATCH_SIZE,
    shuffle=True  # whether to shuffle the data
)

G = nn.Sequential(  # Generator
    nn.Linear(N_IDEAS, 128),  # random ideas (could come from a normal distribution)
    nn.ReLU(),
    nn.Linear(128, ART_COMPONENTS),  # making a painting from these random ideas
    nn.ReLU(),
)

D = nn.Sequential(  # Discriminator
    nn.Linear(ART_COMPONENTS, 128),  # receives art work either from MNIST or from G
    nn.ReLU(),
    nn.Linear(128, 1),
    nn.Sigmoid(),  # tells the probability that the art work is real
)

# loss & optimizer
optimD = torch.optim.Adam(D.parameters(), lr=LR_D)
optimG = torch.optim.Adam(G.parameters(), lr=LR_G)

label_Real = torch.FloatTensor(BATCH_SIZE).fill_(1)  # not used by the hand-written log loss below
label_Fake = torch.FloatTensor(BATCH_SIZE).fill_(0)  # (they would be needed with nn.BCELoss)

filePath = []
for epoch in range(EPOCH):
    for step, (images, imagesLabel) in enumerate(train_loader):
        G_ideas = torch.randn((BATCH_SIZE, N_IDEAS))
        G_paintings = G(G_ideas)
        images = images.reshape(BATCH_SIZE, -1)

        # train D: detach the fakes so D's update does not backpropagate into G
        # (stepping D and then calling backward on the same graph crashes on recent PyTorch)
        prob_artist0 = D(images)  # D tries to increase this prob
        prob_artist1 = D(G_paintings.detach())  # D tries to decrease this prob
        D_loss = -torch.mean(torch.log(prob_artist0) + torch.log(1. - prob_artist1))
        optimD.zero_grad()
        D_loss.backward()
        optimD.step()

        # train G: re-evaluate D on the fakes with gradients flowing back into G
        prob_artist1 = D(G_paintings)
        G_loss = torch.mean(torch.log(1. - prob_artist1))
        optimG.zero_grad()
        G_loss.backward()
        optimG.step()

        if step % 20 == 0:
            plt.cla()
            picture = torch.squeeze(G_paintings[0]).detach().numpy().reshape((28, 28))
            plt.imshow(picture, cmap=plt.cm.gray_r)
            plt.savefig(PNGFILE + '%d-%d.png' % (epoch, step))
            filePath.append(PNGFILE + '%d-%d.png' % (epoch, step))

generated_images = []
for png_path in filePath:
    generated_images.append(imageio.imread(png_path))
shutil.rmtree(PNGFILE)
imageio.mimsave('gan-mnist.gif', generated_images, 'GIF', duration=0.1)
```

[GANs_pytorch-GANs_-]: https://blog.csdn.net/a19990412/article/details/83856083
[Gans_Pytorch_Gans_70]: https://blog.csdn.net/a19990412/article/details/83744225
[DCGANs_pytorch]: https://blog.csdn.net/a19990412/article/details/83928414
[MNIST]: https://blog.csdn.net/a19990412/article/details/83934447
[20181111101752788.gif]: /images/20220417/aefde3c96eb4409083cb3f060766610e.png
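The two loss expressions in the training loop are the original minimax GAN objective: the discriminator minimizes `-mean(log D(x) + log(1 - D(G(z))))`, and the generator minimizes `mean(log(1 - D(G(z))))`. A minimal NumPy sketch of the same arithmetic (the function names `d_loss` and `g_loss` are my own, not from the code above) makes the values easy to check:

```python
import numpy as np

def d_loss(p_real, p_fake):
    # discriminator loss: -mean(log D(x) + log(1 - D(G(z))))
    return -np.mean(np.log(p_real) + np.log(1.0 - p_fake))

def g_loss(p_fake):
    # generator loss as used above: mean(log(1 - D(G(z))))
    # minimizing it pushes D's probability on fakes toward 1
    return np.mean(np.log(1.0 - p_fake))

# when D outputs 0.5 everywhere (a maximally confused discriminator),
# the discriminator loss is 2*log(2)
print(d_loss(np.array([0.5]), np.array([0.5])))  # ≈ 1.3863
```

One well-known caveat: this `log(1 - D(G(z)))` form saturates early in training, when D easily rejects the fakes, so many implementations train G with `-log D(G(z))` instead; the code above keeps the original form.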