• 大小: 1KB
    文件类型: .zip
    金币: 1
    下载: 0 次
    发布日期: 2021-03-24
  • 语言: Matlab
  • 标签: 深度学习  MATLAB  DBN  

资源简介

深度学习领域,dbn网络的训练代码,已经证实能够正常使用,用于matlab仿真专用

资源截图

代码片段和文件信息

function rbm = rbmtrain(rbm x opts)
    assert(isfloat(x) ‘x must be a float‘);
    assert(all(x(:)>=0) && all(x(:)<=1) ‘all data in x must be in [0:1]‘);
    m = size(x 1);
    numbatches = m / opts.batchsize;
    
    assert(rem(numbatches 1) == 0 ‘numbatches not integer‘);

    for i = 1 : opts.numepochs
        kk = randperm(m);
        err = 0;
        for l = 1 : numbatches
            batch = x(kk((l - 1) * opts.batchsize + 1 : l * opts.batchsize) :);
            
            v1 = batch;
            h1 = sigmrnd(repmat(rbm.c‘ opts.batchsize 1) + v1 * rbm.W‘);
            v2 = sigmrnd(repmat(rbm.b‘ opts.batchsize 1) + h1 * rbm.W);
            h2 = sigm(repmat(rbm.c‘ opts.batchsize 1) + v2 * rbm.W‘);

            c1 = h1‘ * v1;
            c2 = h2‘ * v2;

            rbm.vW = rbm.momentum * rbm.vW + rbm.alpha * (c1 - c2)     / opts.batchsize;
            rbm.vb = rbm.momentum * rbm.vb + rbm.alpha * sum(v1 - v2)‘ / opts.batchsize;
            rbm.vc = rbm.momentum * rbm.vc + rbm.alpha * sum(h1 - h2)‘ / opts.batchsize;

            rbm.W = rbm.W + rbm.vW;
            rbm.b = rbm.b + rbm.vb;
            rbm.c = rbm.c + rbm.vc;

            err = err + sum(sum((v1 - v2) .^ 2)) / opts.batchsize;
        end
        
        disp([‘epoch ‘ num2str(i) ‘/‘ num2str(opts.numepochs)  ‘. Average reconstruction error is: ‘ num2str(err / numbatches)]);
        
    end
end

 属性            大小     日期    时间   名称
----------- ---------  ---------- -----  ----
     文件        1401  2020-09-26 16:52  rbmtrain.m

评论

共有 条评论