• 大小: 4KB
    文件类型: .py
    金币: 1
    下载: 0 次
    发布日期: 2021-06-15
  • 语言: Python
  • 标签: gan  tensorflow  mnist  

资源简介

Tensorflow实现GAN生成mnist手写数字图片。 教程见:https://blog.csdn.net/u012223913/article/details/75051516

资源截图

代码片段和文件信息

# -*- coding: utf-8 -*-
# @Author: adrianna
# @Date:   2017-07-12 10:47:57
# @Last Modified by:   adrianna
# @Last Modified time: 2017-07-13 14:43:05

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import os
from tensorflow.examples.tutorials.mnist import input_data

os.environ[‘TF_CPP_MIN_LOG_LEVEL‘] = ‘2‘

sess = tf.InteractiveSession()

mb_size = 128
Z_dim = 100

mnist = input_data.read_data_sets(‘../../MNIST_data‘ one_hot=True)


def weight_var(shape name):
    return tf.get_variable(name=name shape=shape initializer=tf.contrib.layers.xavier_initializer())


def bias_var(shape name):
    return tf.get_variable(name=name shape=shape initializer=tf.constant_initializer(0))


# discriminater net

X = tf.placeholder(tf.float32 shape=[None 784] name=‘X‘)

D_W1 = weight_var([784 128] ‘D_W1‘)
D_b1 = bias_var([128] ‘D_b1‘)

D_W2 = weight_var([128 1] ‘D_W2‘)
D_b2 = bias_var([1] ‘D_b2‘)


theta_D = [D_W1 D_W2 D_b1 D_b2]


# generator net

Z = tf.placeholder(tf.float32 shape=[None 100] name=‘Z‘)

G_W1 = weight_var([100 128] ‘G_W1‘)
G_b1 = bias_var([128] ‘G_B1‘)

G_W2 = weight_var([128 784] ‘G_W2‘)
G_b2 = bias_var([784] ‘G_B2‘)

theta_G = [G_W1 G_W2 G_b1 G_b2]


def generator(z):
    G_h1 = tf.nn.relu(tf.matmul(z G_W1) + G_b1)
    G_log_prob = tf.matmul(G_h1 G_W2) + G_b2
    G_prob = tf.nn.sigmoid(G_log_prob)

    return G_prob


def discriminator(x):
    D_h1 = tf.nn.relu(tf.matmul(x D_W1) + D_b1)
    D_logit = tf.matmul(D_h1 D_W2) + D_b2
    D_prob = tf.nn.sigmoid(D_logit)
    return D_prob D_logit


G_sample = generator(Z)
D_real D_logit_real = discriminator(X)
D_fake D_logit_fake = discriminator(G_sample)


# D_loss = -tf.reduce_mean(tf.log(D_real) + tf.log(1. - D_fake))
# G_loss = -tf.reduce_mean(tf.log(D_fake))

D_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
    log

评论

共有 条评论