• 大小: 11.06MB
    文件类型: .rar
    金币: 1
    下载: 0 次
    发布日期: 2023-08-04
  • 语言: 其他
  • 标签: GAN  图像处理  

资源简介

生成对抗网络(GAN)实例 代码+数据集
很实用的代码,并且简单易学,对深度学习感兴趣的可以看看
数据集有手写图片的识别,也可以替换成自己的数据集

资源截图

代码片段和文件信息

import torch
import torch.nn.functional as nn
import torch.autograd as autograd
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import os
from torch.autograd import Variable
from tensorflow.examples.tutorials.mnist import input_data


mnist = input_data.read_data_sets(‘../../MNIST_data‘ one_hot=True)
mb_size = 64
Z_dim = 100
X_dim = mnist.train.images.shape[1]
y_dim = mnist.train.labels.shape[1]
h_dim = 128
c = 0
lr = 1e-3


def xavier_init(size):
    in_dim = size[0]
    xavier_stddev = 1. / np.sqrt(in_dim / 2.)
    return Variable(torch.randn(*size) * xavier_stddev requires_grad=True)


“““ ==================== GENERATOR ======================== “““

Wzh = xavier_init(size=[Z_dim h_dim])
bzh = Variable(torch.zeros(h_dim) requires_grad=True)

Whx = xavier_init(size=[h_dim X_dim])
bhx = Variable(torch.zeros(X_dim) requires_grad=True)


def G(z):
    h = nn.relu(z @ Wzh + bzh.repeat(z.size(0) 1))
    X = nn.sigmoid(h @ Whx + bhx.repeat(h.size(0) 1))
    return X


“““ ==================== DISCRIMINATOR ======================== “““

Wxh = xavier_init(size=[X_dim h_dim])
bxh = Variable(torch.zeros(h_dim) requires_grad=True)

Why = xavier_init(size=[h_dim 1])
bhy = Variable(torch.zeros(1) requires_grad=True)


def D(X):
    h = nn.relu(X @ Wxh + bxh.repeat(X.size(0) 1))
    y = nn.sigmoid(h @ Why + bhy.repeat(h.size(0) 1))
    return y


G_params = [Wzh bzh Whx bhx]
D_params = [Wxh bxh Why bhy]
params = G_params + D_params


“““ ===================== TRAINING ======================== “““


def reset_grad():
    for p in params:
        if p.grad is not None:
            data = p.grad.data
            p.grad = Variable(data.new().resize_as_(data).zero_())


G_solver = optim.Adam(G_params lr=1e-3)
D_solver = optim.Adam(D_params lr=1e-3)

ones_label = Variable(torch.ones(mb_size 1))
zeros_label = Variable(torch.zeros(mb_size 1))


for it in range(100000):
    # Sample data
    z = Variable(torch.randn(mb_size Z_dim))
    X _ = mnist.train.next_batch(mb_size)
    X = Variable(torch.from_numpy(X))

    # Dicriminator forward-loss-backward-update
    G_sample = G(z)
    D_real = D(X)
    D_fake = D(G_sample)

    D_loss_real = nn.binary_cross_entropy(D_real ones_label)
    D_loss_fake = nn.binary_cross_entropy(D_fake zeros_label)
    D_loss = D_loss_real + D_loss_fake

    D_loss.backward()
    D_solver.step()

    # Housekeeping - reset gradient
    reset_grad()

    # Generator forward-loss-backward-update
    z = Variable(torch.randn(mb_size Z_dim))
    G_sample = G(z)
    D_fake = D(G_sample)

    G_loss = nn.binary_cross_entropy(D_fake ones_label)

    G_loss.backward()
    G_solver.step()

    # Housekeeping - reset gradient
    reset_grad()

    # Print and plot every now and then
    if it % 1000 == 0:
        print(‘Iter-{}; D_loss: {}

 属性            大小     日期    时间   名称
----------- ---------  ---------- -----  ----

     文件    1648877  2018-11-06 15:59  Demo\MNIST_data\t10k-images-idx3-ubyte.gz

     文件       4542  2018-11-06 15:59  Demo\MNIST_data\t10k-labels-idx1-ubyte.gz

     文件    9912422  2018-11-06 15:59  Demo\MNIST_data\train-images-idx3-ubyte.gz

     文件      28881  2018-11-06 15:59  Demo\MNIST_data\train-labels-idx1-ubyte.gz

     文件       3723  2018-12-01 15:16  Demo\vanilla_gan\gan_pytorch.py

     文件       3586  2018-12-02 10:16  Demo\vanilla_gan\gan_tensorflow.py

     目录          0  2018-12-05 21:28  Demo\vanilla_gan\out

     目录          0  2018-12-05 21:28  Demo\MNIST_data

     目录          0  2018-12-05 21:28  Demo\vanilla_gan

     目录          0  2018-12-05 21:28  Demo

----------- ---------  ---------- -----  ----

             11602031                    10


评论

共有 条评论