• 大小: 0.17M
    文件类型: .rar
    金币: 1
    下载: 0 次
    发布日期: 2024-05-08
  • 语言: Python
  • 标签: mnist  

资源简介

人工智能算法实现mnist手写数字识别

资源截图

代码片段和文件信息

import tensorflow as tf
import os
import numpy as np
from matplotlib import pyplot as plt
from tensorflow.keras.layers import Conv2D BatchNormalization Activation MaxPool2D Dropout Flatten Dense
from tensorflow.keras import Model

np.set_printoptions(threshold=np.inf)
mnist = tf.keras.datasets.mnist
#fashion = tf.keras.datasets.fashion_mnist
(x_train y_train) (x_test y_test) = mnist.load_data()
x_train x_test = x_train / 255.0 x_test / 255.0
print(“x_train.shape“ x_train.shape)
x_train = x_train.reshape(x_train.shape[0] 28 28 1)  # 给数据增加一个维度,使数据和网络结构匹配
x_test = x_test.reshape(x_test.shape[0] 28 28 1)
print(“x_train.shape“ x_train.shape)


class baseline(Model):
    def __init__(self):
        super(baseline self).__init__()
        self.c1 = Conv2D(filters=6 kernel_size=(5 5) padding=‘same‘)  # 卷积层
        self.b1 = BatchNormalization()  # BN层
        self.a1 = Activation(‘relu‘)  # 激活层
        self.p1 = MaxPool2D(pool_size=(2 2) strides=2 padding=‘same‘)  # 池化层
        self.d1 = Dropout(0.2)  # dropout层

        self.flatten = Flatten()
        self.f1 = Dense(128 activation=‘relu‘)
        self.d2 = Dropout(0.2)
        self.f2 = Dense(10 activation=‘softmax‘)

    def call(self x):
        x = self.c1(x)
        x = self.b1(x)
        x = self.a1(x)
        x = self.p1(x)
        x = self.d1(x)

        x = self.flatten(x)
        x = self.f1(x)
        x = self.d2(x)
        y = self.f2(x)
        return y


model = baseline()

model.compile(optimizer=‘adam‘
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)
              metrics=[‘sparse_categorical_accuracy‘])

checkpoint_save_path = “./checkpoint/baseline.ckpt“
if os.path.exists(checkpoint_save_path + ‘.index‘):
    print(‘-------------load the model-----------------‘)
    model.load_weights(checkpoint_save_path)

cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path
                                                 save_weights_only=True
                                                 save_best_only=True)

history = model.fit(x_train y_train batch_size=32 epochs=5 validation_data=(x_test y_test) validation_freq=1
                    callbacks=[cp_callback])
model.summary()

# print(model.trainable_variables)
file = open(‘./weights.txt‘ ‘w‘)
for v in model.trainable_variables:
    file.write(str(v.name) + ‘\n‘)
    file.write(str(v.shape) + ‘\n‘)
    file.write(str(v.numpy()) + ‘\n‘)
file.close()

###############################################    show   ###############################################

# 显示训练集和验证集的acc和loss曲线
acc = history.history[‘sparse_categorical_accuracy‘]
val_acc = history.history[‘val_sparse_categorical_accuracy‘]
loss = history.history[‘loss‘]
val_loss = history.history[‘val_loss‘]

plt.subplot(1 2 1)
plt.plot(acc label=‘Training Accuracy‘)
plt.plot(val_acc label=‘Validation Accuracy

 属性            大小     日期    时间   名称
----------- ---------  ---------- -----  ----

     文件       3330  2020-07-11 03:54  mnist实验报告\ML_MNIST\CNN.py

     文件       1495  2020-07-11 04:48  mnist实验报告\ML_MNIST\FCN.py

     文件       1843  2020-07-11 05:38  mnist实验报告\ML_MNIST\RNN.py

     文件     256512  2020-11-13 22:11  mnist实验报告\mnist实验报告.doc

     目录          0  2020-07-11 06:00  mnist实验报告\ML_MNIST

     目录          0  2020-11-13 22:11  mnist实验报告

----------- ---------  ---------- -----  ----

               263180                    6


评论

共有 条评论