• 大小: 2.19MB
    文件类型: .rar
    金币: 1
    下载: 0 次
    发布日期: 2023-08-13
  • 语言: 其他
  • 标签: RtFrecords  

资源简介

TensorFlow 下自制 TFRecords 数据集，采用 one-hot 编码做图像分类的源码

资源截图

代码片段和文件信息

# -*- coding: utf-8 -*-
"""
Created on Sat Feb 23 23:21:44 2019

@author: Administrator
"""

import tensorflow as tf
import numpy as np
from sklearn.preprocessing import OneHotEncoder
from RTFrcord_read_data import read_and_decode


# Input image size in pixels.
# NOTE: `weight` is the image *width* (it is the third NHWC dimension of the
# input placeholder); the name is kept because other code refers to it.
height, weight = 100, 100

# Examples per training batch; also used as the fixed batch dimension of the
# input placeholders and of the flatten step in the network.
batch_size = 432
 
#定义初始化权重和偏置函数
def weight_variable(shape):
    """Create a trainable weight tensor initialised from N(0, 0.01**2).

    Args:
        shape: list of ints giving the variable's shape.

    Returns:
        A ``tf.Variable`` of the given shape.
    """
    return tf.Variable(tf.random_normal(shape, stddev=0.01))
def bias_variable(shape):
    """Create a trainable bias tensor with every element initialised to 0.1.

    Args:
        shape: list of ints giving the variable's shape.

    Returns:
        A ``tf.Variable`` of the given shape.
    """
    return tf.Variable(tf.constant(0.1, shape=shape))
# Input placeholders:
#   X        - a fixed-size batch of 3-channel images, NHWC layout
#              (batch, height, width, channels), as tf.nn.conv2d expects
#              by default;
#   y_       - one-hot labels over 8 classes;
#   keep_pro - dropout keep probability (feed < 1.0 for training,
#              1.0 for evaluation).
X = tf.placeholder(tf.float32, [batch_size, height, weight, 3])
y_ = tf.placeholder(tf.float32, [batch_size, 8])
keep_pro = tf.placeholder(tf.float32)
 
# Build the network.
def model(X, keep_pro):
    """Build a two-conv-layer CNN classifier and return its softmax output.

    Architecture:
        conv 5x5 (3->32) + ReLU -> max-pool 4x4/4
        -> conv 5x5 (32->64) + ReLU -> max-pool 4x4/4
        -> flatten -> FC 1024 + ReLU -> dropout -> FC 8 -> softmax

    Args:
        X: image batch with 3 channels in the last axis (NHWC). Its first
            dimension must equal the module-level ``batch_size``, because the
            flatten step reshapes to ``[batch_size, -1]``.
        keep_pro: dropout keep probability (tensor or float).

    Returns:
        Tensor of shape ``[batch_size, 8]`` with per-class probabilities.
    """
    # Block 1: 5x5 convolution, 3 -> 32 channels, then 4x4 / stride-4 max-pool.
    w1 = weight_variable([5, 5, 3, 32])
    b1 = bias_variable([32])
    conv1 = tf.nn.relu(
        tf.nn.conv2d(X, w1, strides=[1, 1, 1, 1], padding='SAME') + b1)
    pool1 = tf.nn.max_pool(
        conv1, ksize=[1, 4, 4, 1], strides=[1, 4, 4, 1], padding='SAME')

    # Block 2: 5x5 convolution, 32 -> 64 channels, then 4x4 / stride-4 max-pool.
    w2 = weight_variable([5, 5, 32, 64])
    b2 = bias_variable([64])
    conv2 = tf.nn.relu(
        tf.nn.conv2d(pool1, w2, strides=[1, 1, 1, 1], padding='SAME') + b2)
    pool2 = tf.nn.max_pool(
        conv2, ksize=[1, 4, 4, 1], strides=[1, 4, 4, 1], padding='SAME')

    # Flatten. The per-example feature size comes from the static shape
    # (e.g. for 100x100 inputs: ceil(ceil(100/4)/4) = 7, so 7*7*64 = 3136).
    tensor = tf.reshape(pool2, [batch_size, -1])
    dim = tensor.get_shape()[1].value

    # Fully connected hidden layer with dropout.
    w3 = weight_variable([dim, 1024])
    b3 = bias_variable([1024])
    fc1 = tf.nn.relu(tf.matmul(tensor, w3) + b3)
    h_fc1 = tf.nn.dropout(fc1, keep_pro)

    # Output layer: 8 classes, softmax probabilities.
    w4 = weight_variable([1024, 8])
    b4 = bias_variable([8])
    y_conv = tf.nn.softmax(tf.matmul(h_fc1, w4) + b4)
    return y_conv
 
# Build the network, the loss and the optimiser.
y_conv = model(X, keep_pro)
# Cross-entropy on the softmax output. The probabilities are clipped away from
# 0 before tf.log: a saturated softmax can emit exact zeros, and log(0) = -inf
# would turn the loss and its gradients into NaN.
cost = tf.reduce_mean(
    -tf.reduce_sum(y_ * tf.log(tf.clip_by_value(y_conv, 1e-10, 1.0)), axis=1))
train_step = tf.train.AdamOptimizer(0.001).minimize(cost)
# Accuracy: fraction of samples whose arg-max prediction matches the label.
correct_prediction = tf.equal(tf.argmax(y_conv, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

# Read single (image, label) examples from the TFRecords file.
image, label = read_and_decode("train1.tfrecords")
# Create the session and train.
with tf.Session() as sess:
    tf.global_variables_initializer().run()
    # Start the queue-runner threads that feed `image` / `label`. The
    # coordinator is used to shut them down in the `finally` clause.
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)
    # Host-side buffers for one batch: images and a single label column.
    # NOTE(review): assumes read_and_decode yields a (height, weight, 3) image
    # and a scalar label -- confirm against RTFrcord_read_data.py.
    example = np.zeros((batch_size, height, weight, 3))
    labels = np.zeros((batch_size, 1))
    try:
        # Pull batch_size individual examples from the input queue.
        for idx in range(batch_size):
            example[idx], labels[idx] = sess.run([image, label])
        print(labels)

        # Convert the label column to one-hot form.
        enc = OneHotEncoder()
        labels = enc.fit_transform(labels)
        labels = labels.toarray()
        print(labels)
        # y_ expects exactly 8 one-hot columns, but OneHotEncoder creates one
        # column per label value it actually saw. Fail with a clear message
        # instead of an opaque feed-shape error.
        if labels.shape[1] != 8:
            raise ValueError(
                'expected 8 label classes in the batch, found %d'
                % labels.shape[1])

        # NOTE(review): this trains repeatedly on the same single batch and
        # reports accuracy on that batch (i.e. training accuracy).
        for i in range(100):
            # One optimisation step with dropout active.
            sess.run(train_step,
                     feed_dict={X: example, y_: labels, keep_pro: 0.5})
            if i % 10 == 0:
                # Evaluate with dropout disabled (keep probability 1.0).
                acc = sess.run(accuracy,
                               feed_dict={X: example, y_: labels, keep_pro: 1.0})
                print('train step', '%04d ' % (i + 1), 'Accuracy=', acc)
    except tf.errors.OutOfRangeError:
        print('done!')
    finally:
        # Stop and join the queue-runner threads so the process can exit.
        coord.request_stop()
        coord.join(threads)

 属性            大小     日期    时间   名称
----------- ---------  ---------- -----  ----

     文件       3234  2019-02-23 23:47  RTFrcords\data_classification.py

     文件       1743  2019-02-23 23:47  RTFrcords\RTFrcord_read_data.py

     文件       1819  2019-02-23 23:08  RTFrcords\RTFrcord_save_data.py

     文件   13016413  2019-02-23 23:09  RTFrcords\train1.tfrecords

     文件        893  2019-02-23 23:27  RTFrcords\__pycache__\RTFrcord_read_data.cpython-36.pyc

     目录          0  2019-02-23 23:27  RTFrcords\__pycache__

     目录          0  2019-02-23 23:27  RTFrcords

----------- ---------  ---------- -----  ----

             13024102                    7


评论

共有 条评论

相关资源