• 大小: 5KB
    文件类型: .zip
    金币: 1
    下载: 0 次
    发布日期: 2021-05-29
  • 语言: 其他
  • 标签: 深度学习  Gan  

资源简介

使用 WGAN 生成二次元人物头像。train 部分的代码写得不全，请根据自己的需求补充。数据使用李宏毅网课上提供的数据集（文件太大，无法上传）。网络结构使用的是 DenseNet。

资源截图

代码片段和文件信息

from GAN.Discriminator import discriminator
from GAN.Generator import generator
from GAN.CreateData import read_data
import tensorflow as tf
import numpy as np
from PIL import Image

# Training hyper-parameters.
batch_size = 64
learning_rate = 1e-4
epochs = 100
# Number of critic (discriminator) updates per generator update.
critic_iters = 5
# WGAN weight-clipping bound that keeps the critic (approximately) Lipschitz.
clip_value = 0.01
# Directory of real face images passed to read_data().
data_dir = './dataset/CartoonCharacters/faces'


def iter_batch_bounds(n_samples, batch_size):
    """Yield ``(start, end)`` slice bounds covering only *full* batches.

    The real-image placeholder and the generator both use a fixed batch
    dimension, so a trailing partial batch cannot be fed. The previous
    ``int(n / batch_size) + 1`` step count also produced an empty final
    batch whenever ``n`` was an exact multiple of ``batch_size``.
    Leftover samples (fewer than ``batch_size``) are skipped.

    Raises:
        ValueError: if ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError('batch_size must be positive')
    for start in range(0, n_samples - batch_size + 1, batch_size):
        yield start, start + batch_size


def _names(items):
    """Return the set of ``.name`` attributes of TF variables/ops/tensors."""
    return {item.name for item in items}


def main():
    """Build the WGAN graph and run the alternating critic/generator loop."""
    import warnings

    true_image = read_data(data_dir)
    n_samples = len(true_image)

    is_training = tf.placeholder(tf.bool)
    dropout_rate = tf.placeholder(tf.float32)
    # Real-image batch; presumably 96x96 RGB faces -- confirm against read_data.
    X = tf.placeholder(tf.float32, [batch_size, 96, 96, 3])

    # Record what the generator adds to the graph. Without an explicit
    # var_list, minimize() trains *every* trainable variable, so the critic
    # step would also move the generator (and vice versa).
    vars_before_g = _names(tf.trainable_variables())
    ops_before_g = _names(tf.get_collection(tf.GraphKeys.UPDATE_OPS))
    G = generator(growth_rate_K=12, is_training=is_training,
                  dropout_rate=dropout_rate).generator(batch_size=batch_size)
    g_var_names = _names(tf.trainable_variables()) - vars_before_g
    g_update_ops = [op for op in tf.get_collection(tf.GraphKeys.UPDATE_OPS)
                    if op.name not in ops_before_g]

    D_real = discriminator(growth_rate_K=12, is_training=is_training,
                           dropout_rate=dropout_rate).discriminator(input=X)
    vars_after_real = _names(tf.trainable_variables())
    D_fake = discriminator(growth_rate_K=12, is_training=is_training,
                           dropout_rate=dropout_rate).discriminator(input=G)
    if _names(tf.trainable_variables()) - vars_after_real:
        # The WGAN loss assumes one critic scores both real and fake images.
        warnings.warn('D_real and D_fake do not share variables; '
                      'the discriminator should reuse its weights.')

    g_vars = [v for v in tf.trainable_variables() if v.name in g_var_names]
    d_vars = [v for v in tf.trainable_variables()
              if v.name not in g_var_names]

    G_cost = -tf.reduce_mean(D_fake)
    D_cost = tf.reduce_mean(D_fake) - tf.reduce_mean(D_real)

    # The generator step depends only on the generator's own update ops:
    # depending on update ops built from D_real (e.g. batch-norm statistics,
    # if the discriminator registers any) would require feeding X, which
    # the generator step does not do.
    with tf.control_dependencies(g_update_ops):
        G_train_op = tf.train.AdamOptimizer(
            learning_rate=learning_rate).minimize(G_cost, var_list=g_vars)
    with tf.control_dependencies(
            tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
        D_train_op = tf.train.AdamOptimizer(
            learning_rate=learning_rate).minimize(D_cost, var_list=d_vars)

    # Original WGAN: clip critic weights after every critic update so that
    # D_fake - D_real stays bounded (Lipschitz constraint).
    clip_D = [v.assign(tf.clip_by_value(v, -clip_value, clip_value))
              for v in d_vars]

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        for epoch in range(epochs):
            for i, (start, end) in enumerate(
                    iter_batch_bounds(n_samples, batch_size)):
                batch = true_image[start:end]
                feed_dict_real = {
                    is_training: True,
                    dropout_rate: 0.2,
                    X: batch,
                }
                feed_dict_fake = {
                    is_training: True,
                    dropout_rate: 0.2,
                }
                # NOTE: every critic iteration reuses the same real batch.
                for _ in range(critic_iters):
                    _, D_loss = sess.run([D_train_op, D_cost],
                                         feed_dict=feed_dict_real)
                    sess.run(clip_D)
                _, G_loss = sess.run([G_train_op, G_cost],
                                     feed_dict=feed_dict_fake)

                print('Epochs: {} training: {} D_cost: {} G_cost: {}'.format(
                    epoch, i, D_loss, G_loss))


if __name__ == '__main__':
    main()

 属性            大小     日期    时间   名称
----------- ---------  ---------- -----  ----
     目录           0  2019-09-27 16:48  tensorflow-GAN-CreateCartoonFaces\
     目录           0  2019-09-27 17:00  tensorflow-GAN-CreateCartoonFaces\dataset\
     目录           0  2019-09-27 17:00  tensorflow-GAN-CreateCartoonFaces\dataset\CartoonCharacters\
     文件        2102  2019-09-27 16:48  tensorflow-GAN-CreateCartoonFaces\train.py
     目录           0  2019-09-27 16:52  tensorflow-GAN-CreateCartoonFaces\Network\
     文件           0  2019-09-24 13:37  tensorflow-GAN-CreateCartoonFaces\Network\__init__.py
     文件        7218  2019-09-27 13:36  tensorflow-GAN-CreateCartoonFaces\Network\DenseNet.py
     目录           0  2019-09-27 16:49  tensorflow-GAN-CreateCartoonFaces\GAN\
     文件           0  2019-09-27 12:34  tensorflow-GAN-CreateCartoonFaces\GAN\__init__.py
     文件         290  2019-09-27 16:49  tensorflow-GAN-CreateCartoonFaces\GAN\CreateData.py
     文件         319  2019-09-27 15:11  tensorflow-GAN-CreateCartoonFaces\GAN\Discriminator.py
     文件         825  2019-09-27 16:49  tensorflow-GAN-CreateCartoonFaces\GAN\Generator.py

评论

共有 条评论