• 大小: 5KB
    文件类型: .zip
    金币: 1
    下载: 0 次
    发布日期: 2021-06-03
  • 语言: Python
  • 标签: 数字识别  CNN  

资源简介

基于卷积神经网络的手写数字识别,工具使用Google的人工智能TensorFlow库,识别准确率高,代码使用python3.0以上版本

资源截图

代码片段和文件信息

import tensorflow as tf
import numpy as np
from PIL import Image
import forward as fw
import backward as bw
import time
import os

os.environ[‘TF_CPP_MIN_LOG_LEVEL‘]=‘2‘
def restore_model(testPivArr):
    with tf.Graph().as_default() as g:
        x= tf.placeholder(tf.float32[1fw.IMAGE_SIZE fw.IMAGE_SIZE 1])
        # x = tf.placeholder(tf.float32 [1784])
        y= fw.forward(xFalseNone)
        preValue= tf.argmax(y1)

        variable_averages= tf.train.ExponentialMovingAverage(bw.MOVING_AVERAGE_DECAY)
        variable_to_restore= variable_averages.variables_to_restore()
        saver= tf.train.Saver(variable_to_restore)

        with tf.Session() as sess:
            ckpt = tf.train.get_checkpoint_state(bw.MODEL_SAVE_PATH)
            if ckpt and ckpt.model_checkpoint_path:
                saver.restore(sess ckpt.model_checkpoint_path)
                ypreValue= sess.run([ypreValue]feed_dict={x:testPivArr})
                return preValue
            else:
                print(‘没有找到训练节点‘)
                return -1

def pre_pic(picName):
    img= Image.open(picName)
    im_arr= np.array(img.convert(‘L‘))
    threshold=50
    for i in range(fw.IMAGE_SIZE):
        for j in range(fw.IMAGE_SIZE):
            im_arr[i][j]=255- im_arr[i][j]
            if im_arr[i][j]                im_arr[i][j]=0
            else:
                im_arr[i][j] = 255
    nm_arr= im_arr.reshape([1784])
    nm_arr= nm_arr.astype(np.float32)
    img_ready= np.dot(nm_arr1.0/255.0)
    img_ready=np.reshape(img_ready[-1fw.IMAGE_SIZE fw.IMAGE_SIZE 1])
    return img_ready

def application():
    testPath = input(“input the folder of test pictures:“)
    start = time.clock()
    for filename in os.listdir(testPath):  # listdir的参数是文件夹的路径
        print(“the name of test picture:“filename)
        testPicArr = pre_pic(testPath+“/“+filename)
        preValue = restore_model(testPicArr)
        print(“预测数字可能是:“preValue)
        print(‘###############################################‘)
    elapsed = (time.clock() - start)
    print(‘预测所用时间:‘elapsed)
def main():
    # mnist = input_data.read_data_sets(‘MNIST_data‘ one_hot=True)
    application()
if __name__==‘__main__‘:
    main()

 属性            大小     日期    时间   名称
----------- ---------  ---------- -----  ----
     文件        2649  2018-05-05 17:31  backward.py
     文件        2218  2018-05-05 17:31  forward.py
     文件        1435  2018-04-17 13:57  handWriteNumber.py
     文件        1881  2018-05-05 17:31  test.py
     文件        2343  2018-05-05 17:31  app.py

评论

共有 条评论