• 大小: 6KB
    文件类型: .py
    金币: 1
    下载: 0 次
    发布日期: 2021-06-01
  • 语言: Python
  • 标签: VGG16  MNIST  

资源简介

使用VGG16网络实现对传统MNIST手写数据集的识别任务。

资源截图

代码片段和文件信息

#Create Wed May 2019-5-29 19:37:16
#End 2019-5-29 21:30:35

from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf

mnist = input_data.read_data_sets(‘MNIST/‘ one_hot = True)

x = tf.placeholder(tf.float32 [None 784])
y = tf.placeholder(tf.float32 [None 10])
keep_prob = tf.placeholder(tf.float32)

def conv2d(name x w b):
return tf.nn.relu(tf.nn.bias_add(tf.nn.conv2d(x w strides = [1 1 1 1] padding = ‘SAME‘) b) name = name)

def max_pool(name x):
return tf.nn.max_pool(x ksize = [1 2 2 1] strides = [1 2 2 1] padding = ‘SAME‘ name = name)

def norm(name x):
return tf.nn.lrn(x depth_radius = None bias = 0.01 alpha = 0.001 beta = 1.0 name = name)

weights = {
‘wc1‘: tf.Variable(tf.random_normal([3 3 1 64]))
‘wc2‘: tf.Variable(tf.random_normal([3 3 64 64]))
‘wc3‘: tf.Variable(tf.random_normal([3 3 64 128]))
‘wc4‘: tf.Variable(tf.random_normal([3 3 128 128]))
‘wc5‘: tf.Variable(tf.random_normal([3 3 128 256]))
‘wc6‘: tf.Variable(tf.random_normal([3 3 256 256]))
‘wc7‘: tf.Variable(tf.random_normal([3 3 256 256]))
‘wc8‘: tf.Variable(tf.random_normal([3 3 256 256]))
‘wc9‘: tf.Variable(tf.random_normal([3 3 256 512]))
‘wc10‘: tf.Variable(tf.random_normal([3 3 512 512]))
‘wc11‘: tf.Variable(tf.random_normal([3 3 512 512]))
‘wc12‘: tf.Variable(tf.random_normal([3 3 512 512]))
‘wc13‘: tf.Variable(tf.random_normal([3 3 512 512]))
‘wc14‘: tf.Variable(tf.random_normal([3 3 512 512]))
‘wc15‘: tf.Variable(tf.random_normal([3 3 512 512]))
‘wc16‘: tf.Variable(tf.random_normal([3 3 512 256]))

‘wd1‘: tf.Variable(tf.random_normal([4*4*256 4096]))
‘wd2‘: tf.Variable(tf.random_normal([4096 4096]))
‘out‘: tf.Variable(tf.random_normal([4096 10]))
}

biases = {
‘bc1‘: tf.Variable(tf.zeros([64]))
‘bc2‘: tf.Variable(tf.zeros([64]))
‘bc3‘: tf.Variable(tf.zeros([128]))
‘bc4‘: tf.Variable(tf.zeros([128]))
‘bc5‘: tf.Variable(tf.zeros([256]))
‘bc6‘: tf.Variable(tf.zeros([256]))
‘bc7‘: tf.Variable(tf.zeros([256]))
‘bc8‘: tf.Variable(tf.zeros([256]))
‘bc9‘: tf.Variable(tf.zeros([512]))
‘bc10‘: tf.Variable(tf.zeros([512]))
‘bc11‘: tf.Variable(tf.zeros([512]))
‘bc12‘: tf.Variable(tf.zeros([512]))
‘bc13‘: tf.Variable(tf.zeros([512]))
‘bc14‘: tf.Variable(tf.zeros([512]))
‘bc15‘: tf.Variable(tf.zeros([512]))
‘bc16‘: tf.Variable(tf.zeros([256]))

‘bd1‘: tf.Variable(tf.zeros([4096]))
‘bd2‘: tf.Variable(tf.zeros([4096]))
‘out‘: tf.Variable(tf.zeros([10]))
}

#2 4 12进行池化
def VGG16(x weights biases dropout):
x = tf.reshape(x shape = [-1 28 28 1])

conv1 = conv2d(‘conv1‘ x weights[‘wc1‘] biases[‘bc1‘])
#28*28*64
norm1 = norm(‘norm1‘ conv1)

conv2 = conv2d(‘conv2‘ norm1 weights[‘wc2‘] biases[‘b

评论

共有 条评论