• 大小: 66KB
    文件类型: .tar
    金币: 1
    下载: 0 次
    发布日期: 2021-05-18
  • 语言: 其他
  • 标签: MATLAB  深度学习  RBM  

资源简介

实现深度玻尔兹曼机的手写识别,数据库是标准手写识别数据库,自己可以去下载

资源截图

代码片段和文件信息

% Version 1.000
%
% Code provided by Ruslan Salakhutdinov 
%
% Permission is granted for anyone to copy use modify or distribute this
% program and accompanying programs and documents for any purpose provided
% this copyright notice is retained and prominently displayed along with
% a note saying that the original programs are available from our
% web page.
% The programs and documents are distributed without any warranty express or
% implied.  As the programs were written for research purposes only they have
% not been tested to the degree that would be advisable in any important
% application.  All use of these programs is entirely at the user‘s own risk.


test_err=[];
test_crerr=[];
train_err=[];
train_crerr=[];

fprintf(1‘\nTraining discriminative model on MNIST by minimizing cross entropy error. \n‘);
fprintf(1‘60 batches of 1000 cases each. \n‘);

[numcases numdims numbatches]=size(batchdata);
N=numcases; 

load fullmnist_dbm
[numdims numhids] = size(vishid);
[numhids numpens] = size(hidpen); 

%%%%%% Preprocess the data %%%%%%%%%%%%%%%%%%%%%%

[testnumcases testnumdims testnumbatches]=size(testbatchdata);
N=testnumcases;
temp_h2_test = zeros(testnumcasesnumpenstestnumbatches); 
for batch = 1:testnumbatches
   data = [testbatchdata(::batch)];
   [temp_h1 temp_h2] = ...
       mf_class(datavishidhidbiasesvisbiaseshidpenpenbiases);
   temp_h2_test(::batch) = temp_h2;
end  

[numcases numdims numbatches]=size(batchdata);
N=numcases;
temp_h2_train = zeros(numcasesnumpensnumbatches);
for batch = 1:numbatches
   data = [batchdata(::batch)];
   [temp_h1 temp_h2] = ...
     mf_class(datavishidhidbiasesvisbiaseshidpenpenbiases);
   temp_h2_train(::batch) = temp_h2;
end

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

w1_penhid = hidpen‘;
w1_vishid = vishid;
w2 = hidpen;
h1_biases = hidbiases; h2_biases = penbiases; 

w_class = 0.1*randn(numpens10); 
topbiases = 0.1*randn(110);

for epoch = 1:maxepoch 

%%%% TEST STATS 
%%%% Error rates 
   [testnumcases testnumdims testnumbatches]=size(testbatchdata);
   N=testnumcases;
   bias_hid= repmat(h1_biasesN1);
   bias_pen = repmat(h2_biasesN1);
   bias_top = repmat(topbiasesN1);

   err=0;
   err_cr=0;
   counter=0;  
   for batch = 1:testnumbatches
     data = [testbatchdata(::batch)];
     temp_h2 = temp_h2_test(::batch); 
     target = [testbatchtargets(::batch)]; 

     w1probs = 1./(1 + exp(-data*w1_vishid -temp_h2*w1_penhid - bias_hid  )); 
     w2probs = 1./(1 + exp(-w1probs*w2 - bias_pen)); 
     targetout = exp(w2probs*w_class + bias_top );
     targetout = targetout./repmat(sum(targetout2)110);
     [I J]=max(targetout[]2); 
     [I1 J1]=max(target[]2); 
     counter=counter+length(find(J~=J1));  
     err_cr = err_cr- sum(sum( target(:1:end).*log(targetout))) ;
   end

   test_err(epoch)=counter;
   test_crerr(epoch)=err_cr;
   fprintf(1‘\nepoch %d test  misclassification err %d (out of 10000)  test cross entropy error

评论

共有 条评论