资源简介

Deep Learning using Linear Support Vector Machines的简单实现代码

资源截图

代码片段和文件信息

# Copyright 2017 Abien Fred Agarap

# Licensed under the Apache License Version 2.0 (the “License“);
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#    http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing software
# distributed under the License is distributed on an “AS IS“ BASIS
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

“““Implementation of the CNN classes“““
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

__version__ = ‘0.1.0‘
__author__ = ‘Abien Fred Agarap‘

import argparse
from model.cnn_softmax import CNN
from model.cnn_svm import CNNSVM
from tensorflow.examples.tutorials.mnist import input_data


def parse_args():
    parser = argparse.ArgumentParser(description=‘CNN & CNN-SVM for Image Classification‘)
    group = parser.add_argument_group(‘Arguments‘)
    group.add_argument(‘-m‘ ‘--model‘ required=True type=str
                       help=‘[1] CNN-Softmax [2] CNN-SVM‘)
    group.add_argument(‘-d‘ ‘--dataset‘ required=True type=str
                       help=‘path of the MNIST dataset‘)
    group.add_argument(‘-p‘ ‘--penalty_parameter‘ required=False type=int
                       help=‘the SVM C penalty parameter‘)
    group.add_argument(‘-c‘ ‘--checkpoint_path‘ required=True type=str
                       help=‘path where to save the trained model‘)
    group.add_argument(‘-l‘ ‘--log_path‘ required=True type=str
                       help=‘path where to save the TensorBoard logs‘)
    arguments = parser.parse_args()
    return arguments


if __name__ == ‘__main__‘:
    args = parse_args()

    mnist = input_data.read_data_sets(args.dataset one_hot=True)
    num_classes = mnist.train.labels.shape[1]
    sequence_length = mnist.train.images.shape[1]
    model_choice = args.model

    assert model_choice == ‘1‘ or model_choice == ‘2‘ “Invalid choice: Choose between 1 and 2 only.“

    if model_choice == ‘1‘:
        model = CNN(alpha=1e-3 batch_size=128 num_classes=num_classes num_features=sequence_length)
        model.train(checkpoint_path=args.checkpoint_path epochs=10000 log_path=args.log_path
                    train_data=mnist.train test_data=mnist.test)
    elif model_choice == ‘2‘:
        model = CNNSVM(alpha=1e-3 batch_size=128 num_classes=num_classes num_features=sequence_length
                       penalty_parameter=args.penalty_parameter)
        model.train(checkpoint_path=args.checkpoint_path epochs=10000 log_path=args.log_path
                    train_data=mnist.train test_data=mnist.test)

 属性            大小     日期    时间   名称
----------- ---------  ---------- -----  ----
     目录           0  2017-12-30 11:44  cnn-svm-master\
     文件       11347  2017-12-30 11:44  cnn-svm-master\LICENSE
     文件        7218  2017-12-30 11:44  cnn-svm-master\README.md
     目录           0  2017-12-30 11:44  cnn-svm-master\figures\
     文件       19569  2017-12-30 11:44  cnn-svm-master\figures\accuracy-loss-fashion.png
     文件       15839  2017-12-30 11:44  cnn-svm-master\figures\accuracy-loss-mnist.png
     目录           0  2017-12-30 11:44  cnn-svm-master\logs\
     目录           0  2017-12-30 11:44  cnn-svm-master\logs\Sat Dec  9 13:38:20 2017-training\
     文件      153111  2017-12-30 11:44  cnn-svm-master\logs\Sat Dec  9 13:38:20 2017-training\events.out.tfevents.1512797900.darth-Inspiron7559
     目录           0  2017-12-30 11:44  cnn-svm-master\logs\Sat Dec  9 13:43:41 2017-training\
     文件      172715  2017-12-30 11:44  cnn-svm-master\logs\Sat Dec  9 13:43:41 2017-training\events.out.tfevents.1512798221.darth-Inspiron7559
     文件        2819  2017-12-30 11:44  cnn-svm-master\main.py
     目录           0  2017-12-30 11:44  cnn-svm-master\model\
     文件        8768  2017-12-30 11:44  cnn-svm-master\model\cnn_softmax.py
     文件        9215  2017-12-30 11:44  cnn-svm-master\model\cnn_svm.py
     文件          40  2017-12-30 11:44  cnn-svm-master\requirements.txt
     文件         504  2017-12-30 11:44  cnn-svm-master\setup.sh
     目录           0  2017-12-30 11:44  cnn-svm-master\trained-cnn-softmax\
     文件    39295616  2017-12-30 11:44  cnn-svm-master\trained-cnn-softmax\CNN-Softmax-9900.data-00000-of-00001
     文件         914  2017-12-30 11:44  cnn-svm-master\trained-cnn-softmax\CNN-Softmax-9900.index
     文件       77531  2017-12-30 11:44  cnn-svm-master\trained-cnn-softmax\CNN-Softmax-9900.meta
     文件         183  2017-12-30 11:44  cnn-svm-master\trained-cnn-softmax\checkpoint
     目录           0  2017-12-30 11:44  cnn-svm-master\trained-cnn-svm\
     文件    39295616  2017-12-30 11:44  cnn-svm-master\trained-cnn-svm\CNN-SVM-9900.data-00000-of-00001
     文件         914  2017-12-30 11:44  cnn-svm-master\trained-cnn-svm\CNN-SVM-9900.index
     文件       87201  2017-12-30 11:44  cnn-svm-master\trained-cnn-svm\CNN-SVM-9900.meta
     文件          81  2017-12-30 11:44  cnn-svm-master\trained-cnn-svm\checkpoint

评论

共有 条评论