资源简介

神经网络入门代码,keras实现,MNIST数据集识别,详情见博客:http://blog.csdn.net/adamshan/article/details/79004784

资源截图

代码片段和文件信息

from __future__ import print_function

import keras
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Dense
from keras.optimizers import SGD
from matplotlib import pyplot as plt

def read_data(num_classes):
    # the data shuffled and split between train and test sets
    (x_train y_train) (x_test y_test) = mnist.load_data()

    x_train = x_train.reshape(60000 784)
    x_test = x_test.reshape(10000 784)
    x_train = x_train.astype(‘float32‘)
    x_test = x_test.astype(‘float32‘)
    x_train /= 255
    x_test /= 255
    print(x_train.shape[0] ‘train samples‘)
    print(x_test.shape[0] ‘test samples‘)

    # convert class vectors to binary class matrices
    y_train = keras.utils.to_categorical(y_train num_classes)
    y_test = keras.utils.to_categorical(y_test num_classes)
    return x_train y_train x_test y_test

def model(x_train y_train x_test y_test batch_size epochs num_classes):
    model = Sequential()
    model.add(Dense(15 activation=‘relu‘ input_shape=(784)))
    model.add(Dense(num_classes activation=‘softmax‘))

    model.summary()

    model.compile(loss=‘categorical_crossentropy‘
                  optimizer=SGD(lr=0.01)
                  metrics=[‘accuracy‘])

    history = model.fit(x_train y_train
                        batch_size=batch_size
                        epochs=epochs
                        verbose=1
                        validation_data=(x_test y_test))

    ### print the keys contained in the history object
    print(history.history.keys())
    plot_training(history=history)
    model.save(‘model.json‘)

    score = model.evaluate(x_test y_test verbose=0)
    print(‘Test loss:‘ score[0])
    print(‘Test accuracy:‘ score[1])


def plot_training(history):
    ### plot the training and validation loss for each epoch
    plt.plot(history.history[‘loss‘])
    plt.plot(history.history[‘val_loss‘])
    plt.title(‘model mean squared error loss‘)
    plt.ylabel(‘mean squared error loss‘)
    plt.xlabel(‘epoch‘)
    plt.legend([‘training set‘ ‘validation set‘] loc=‘upper right‘)
    plt.show()


def show_samples(samples labels):
    plt.figure(figsize=(12 12))
    for i in range(len(samples)):
        plt.subplot(4 4 i+1)
        plt.imshow(samples[i] cmap=‘gray‘)
        plt.title(labels[i])
    plt.show()


if __name__ == ‘__main__‘:
    batch_size = 128
    num_classes = 10
    epochs = 20

    x_train y_train x_test y_test = read_data(num_classes)

    model(x_train y_train x_test y_test batch_size epochs num_classes)




 属性            大小     日期    时间   名称
----------- ---------  ---------- -----  ----
     目录           0  2018-01-08 16:49  code\
     文件     5382352  2018-01-08 16:49  code\model.json
     文件        2588  2018-01-08 16:49  code\model.py
     目录           0  2018-01-08 16:49  code\.ipynb_checkpoints\
     文件      330403  2018-01-08 16:49  code\.ipynb_checkpoints\Untitled-checkpoint.ipynb
     文件      330403  2018-01-08 16:49  code\model.ipynb
     目录           0  2018-01-08 16:49  code\.idea\
     目录           0  2018-01-08 16:49  code\.idea\inspectionProfiles\
     文件         562  2018-01-08 16:49  code\.idea\inspectionProfiles\Project_Default.xml
     文件         459  2018-01-08 16:49  code\.idea\code.iml
     文件       12139  2018-01-08 16:49  code\.idea\workspace.xml
     文件         260  2018-01-08 16:49  code\.idea\modules.xml
     文件         209  2018-01-08 16:49  code\.idea\misc.xml

评论

共有 条评论