资源简介

梯度下降纯手工实现 MLP CNN RNN SEQ2SEQ识别手写体MNIST数据集十分类问题代码详解.

资源截图

代码片段和文件信息

import tensorflow as tf
import tensorflow.examples.tutorials.mnist.input_data as input_data  # 导入下载数据集手写体

mnist = input_data.read_data_sets(‘../MNIST_data/‘ one_hot=True)  # 下载数据集


class CNNNet:
    def __init__(self):
        self.x = tf.placeholder(dtype=tf.float32 shape=[None 28 28 1] name=‘input_x‘)
        self.y = tf.placeholder(dtype=tf.float32 shape=[None 10] name=‘input_y‘)

        self.w1 = tf.Variable(
            tf.truncated_normal(shape=[3 3 1 16] dtype=tf.float32 stddev=tf.sqrt(1 / 16) name=‘w1‘))
        self.b1 = tf.Variable(tf.zeros(shape=[16] dtype=tf.float32 name=‘b1‘))

        self.w2 = tf.Variable(
            tf.truncated_normal(shape=[3 3 16 32] dtype=tf.float32 stddev=tf.sqrt(1 / 32) name=‘w2‘))
        self.b2 = tf.Variable(tf.zeros(shape=[32] dtype=tf.float32 name=‘b2‘))

        self.fc_w1 = tf.Variable(
            tf.truncated_normal(shape=[28 * 28 * 32 128] dtype=tf.float32 stddev=tf.sqrt(1 / 128) name=‘fc_w1‘))
        self.fc_b1 = tf.Variable(tf.zeros(shape=[128] dtype=tf.float32 name=‘fc_b1‘))

        self.fc_w2 = tf.Variable(
            tf.truncated_normal(shape=[128 10] dtype=tf.float32 stddev=tf.sqrt(1 / 10) name=‘fc_w2‘))
        self.fc_b2 = tf.Variable(tf.zeros(shape=[10] dtype=tf.float32 name=‘fc_b2‘))

    def forward(self):
        self.conv1 = tf.nn.relu(
            tf.nn.conv2d(self.x self.w1 strides=[1 1 1 1] padding=‘SAME‘ name=‘conv1‘) + self.b1)
        self.conv2 = tf.nn.relu(
            tf.nn.conv2d(self.conv1 self.w2 strides=[1 1 1 1] padding=‘SAME‘ name=‘conv2‘) + self.b2)
        self.flat = tf.reshape(self.conv2 [-1 28 * 28 * 32])
        self.fc1 = tf.nn.relu(tf.matmul(self.flat self.fc_w1) + self.fc_b1)
        self.fc2 = tf.matmul(self.fc1 self.fc_w2) + self.fc_b2
        self.output = tf.nn.softmax(self.fc2)

    def backward(self):
        self.cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=self.fc2 labels=self.y))
        self.opt = tf.train.AdamOptimizer().minimize(self.cost)

    def acc(self):
        self.acc2 = tf.equal(tf.argmax(self.output 1) tf.argmax(self.y 1))
        self.accaracy = tf.reduce_mean(tf.cast(self.acc2 dtype=tf.float32))


if __name__ == ‘__main__‘:
    net = CNNNet()
    net.forward()
    net.backward()
    net.acc()
    init = tf.global_variables_initializer()
    with tf.Session() as sess:
        sess.run(init)
        for i in range(10000):
            ax ay = mnist.train.next_batch(100)
            ax_batch = ax.reshape([-1 28 28 1])
            loss output accaracy _ = sess.run(fetches=[net.cost net.output net.accaracy net.opt]
                                                 feed_dict={net.x: ax_batch net.y: ay})
            # print(loss)
            # print(accaracy)
            if i % 10 == 0:
                test_ax test_ay = mnist.test.next_batch(100)
                test_ax_batch = test_

 属性            大小     日期    时间   名称
----------- ---------  ---------- -----  ----

     文件      21566  2018-11-02 22:06  gradient_descent.png

     文件       1849  2018-11-01 11:33  gradient_descent.py

     文件    1648877  2018-10-30 09:53  MNIST_data\t10k-images-idx3-ubyte.gz

     文件       4542  2018-10-30 09:53  MNIST_data\t10k-labels-idx1-ubyte.gz

     文件    9912422  2018-10-30 09:53  MNIST_data\train-images-idx3-ubyte.gz

     文件      28881  2018-10-30 09:53  MNIST_data\train-labels-idx1-ubyte.gz

     文件       3213  2018-11-02 11:21  SEQ2SEQ.py

     文件       2640  2018-11-01 19:56  RNNNet.py

     文件       3357  2018-11-01 19:56  CNNNet.py

     文件       2205  2018-11-01 12:44  MLPNet.py

     目录          0  2018-11-02 22:07  MNIST_data

----------- ---------  ---------- -----  ----

             11629552                    11


评论

共有 条评论