• 大小: 4KB
    文件类型: .py
    金币: 2
    下载: 1 次
    发布日期: 2021-06-10
  • 语言: Python
  • 标签: rnn  mnist  

资源简介

使用RNN进行mnist的分类,使用的是一个3层的GRU作为模型

资源截图

代码片段和文件信息

#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
from __future__ import print_function
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data


def initialize_weight_bias(in_size out_size):
    weight = tf.truncated_normal(shape=(in_size out_size) stddev=0.01 mean=0.0)
    bias = tf.constant(0.1 shape=[out_size])
    return tf.Variable(weight) tf.Variable(bias)


def model(data target dropout num_hidden=200 num_layers=3):
    “““
    RNN model for mnist classification.
    Args:
        data: input data with shape (batch_size max_time_steps cell_size).
        target : label of input data with shape (batch_size num_classes).
        dropout: dropout rate.
        num_hidden: the number of hidden units.
        num_layers: the number of RNN layers.

    Returns:

    “““
    # establish RNN model
    cells = list()
    for _ in range(num_layers):
        cell = tf.nn.rnn_cell.GRUCell(num_units=num_hidden)
        cell = tf.nn.rnn_cell.DropoutWrapper(cell=cell output_keep_prob=1.0-dropout)
        cells.append(cell)
    network = tf.nn.rnn_cell.MultiRNNCell(cells=cells)
    outputs last_state = tf.nn.dynamic_rnn(cell=network inputs=data dtype=tf.float32)

    # get last output
    outputs = tf.transpose(outputs (1 0 2))
    last_output = tf.gather(outputs int(outputs.get_shape()[0])-1)

    # add softmax layer
    out_size = int(target.get_shape()[1])
    weight bias = initialize_weight_bias(in_size=num_hidden out_size=out_size)
    logits = tf.add(tf.matmul(last_output weight) bias)

    return logits


def main():
    # define some parameters
    default_epochs = 10
    default_batch_size = 64
    default_dropout = 0.5
    test_freq = 150  # every 150 batches
    logs_path = ‘data/log‘

    # get train and test data
    mnist_data = input_data.read_data_sets(‘data/mnist‘ one_hot=True)
    total_steps = int(mnist_data.train.num_examples/default_batch_size)
    total_test_steps = int(mnist_data.test.num_examples/default_batch_size)
    print(‘number of training examples: %d‘ % mnist_data.train.num_examples)  # 55000
    print(‘number of test examples: %d‘ % mnist_data.test.num_examples)  # 10000

    # fit RNN model
    input_x = tf.placeholder(tf.float32 shape=(None 28 28))
    

评论

共有 条评论