资源简介

对matconvnet中用于cnn训练的cnn_train.m文件进行了详细的注释。后续会进行更新,并且会把相关的函数也进行注释

资源截图

代码片段和文件信息

%调用cnn_train:  
% [ net info ] = cnn_train(net imdb @getBatch opts.train ‘val‘ find(imdb.images.set == 3)) ;  
%%
%主体函数cnn_train
function [net stats] = cnn_train(net imdb getBatch varargin)
%CNN_TRAIN  An example implementation of SGD for training CNNs
%    CNN_TRAIN() is an example learner implementing stochastic
%    gradient descent with momentum to train a CNN. It can be used
%    with different datasets and tasks by providing a suitable
%    getBatch function.
%       带有动量功能的SGD实现,可以根据getBatch函数不同用于不同的数据集和任务
%    The function automatically restarts after each training epoch by
%    checkpointing.
%       每次epoch之后重启
%    The function supports training on CPU or on one or more GPUs
%    (specify the list of GPU IDs in the ‘gpus‘ option).

% Copyright (C) 2014-16 Andrea Vedaldi.
% All rights reserved.
%
% This file is part of the VLFeat library and is made available under根据
% the terms of the BSD license (see the COPYING file).
addpath(fullfile(vl_rootnn ‘examples‘));%添加examples的路径

opts.expDir = fullfile(‘data‘‘exp‘) ;%选择保存路径
opts.continue = true ;%每次重启都是接着上次训练状态开始
opts.batchSize = 256 ;%选择初始化批大小为256
opts.numSubBatches = 1 ;%选择子批的个数为1(不划分子批)
opts.train = [] ;%初始化训练集索引为空
opts.val = [] ;%初始化验证集索引为空
opts.gpus = [] ;%选择gpu
opts.epochSize = inf;%inf无穷大量,
opts.prefetch = false ;%选择是否预读取下一批次的样本,初始为否
opts.numEpochs = 300 ;%选择epoch数量
opts.learningRate = 0.001 ;
opts.weightDecay = 0.0005 ;

opts.solver = [] ;  % Empty array means use the default SGD solver使用默认的SGDsolver训练
[opts varargin] = vl_argparse(opts varargin) ;%调用函数修改默认参数配置
%对结构体opts中的内容,用varargin进行更新,opts中没有的元素复制到varargin中
if ~isempty(opts.solver)%如果opts.solver不是空集
  assert(isa(opts.solver ‘function_handle‘) && nargout(opts.solver) == 2...% isa判断输入内容是否为指定类的对象,是的话返回true
    ‘Invalid solver; expected a function handle with two outputs.‘) ;%assert如果cond是false则引发错误并且返回信息。
%如果不为空,则当opts.solver是函数句柄并且输出的参数数目为2时才能继续
  % Call without input arguments to get default options
  opts.solverOpts = opts.solver() ;
end

opts.momentum = 0.9 ;
opts.saveSolverState = true ;
opts.nesterovUpdate = false ;
opts.randomSeed = 0 ;
opts.memoryMapFile = fullfile(tempdir ‘matconvnet.bin‘) ;%tempdir系统的缓存目录。选择内存映射文件
opts.profile = false ;%用于观察每句程序的耗时
opts.parameterServer.method = ‘mmap‘ ;
opts.parameterServer.prefix = ‘mcn‘ ;%词头

opts.conserveMemory = true ;%保存内存
opts.backPropDepth = +inf ;%bp算法的深度
opts.sync = false ;%同步
opts.cudnn = true ;
opts.errorFunction = ‘multiclass‘ ;%多类误差函数
opts.errorLabels = {} ;%初始化错误标签为空,误差的类别,如top1error
opts.plotDiagnostics = false ;%是否绘制诊断信息
opts.plotStatistics = true;%是否绘制过程统计信息
opts.postEpochFn = [] ;  % postEpochFn(netparamsstate) called after each epoch; can return a new learning rate 0 to stop [] for no change
%每次之后可以更换学习速率
opts = vl_argparse(opts varargin) ;%调用函数修改默认参数配置
%%
%初始化准备工作
if ~exist(opts.expDir ‘dir‘) mkdir(opts.expDir) ; end%如果不存在保存路径,就创建它
if isempty(opts.train) opts.train = find(i

评论

共有 条评论