• 大小: 3KB
    文件类型: .rar
    金币: 1
    下载: 0 次
    发布日期: 2021-05-25
  • 语言: 其他
  • 标签: em算法  

资源简介

曾经为了研究 EM 算法，在网上搜寻了一个月的资料，也没有找到 EM 算法的源代码。后来终于在一位资深教授那里找到了相关资料，特地传上来与大家共享。

资源截图

代码片段和文件信息

function [test_targets param_struct] = EM(train_patterns train_targets test_patterns Ngaussians)

% Classify using the expectation-maximization algorithm
% Inputs:
%  train_patterns - Train patterns
% train_targets - Train targets
%   test_patterns   - Test  patterns
%   Ngaussians      - Number for Gaussians for each class (vector)
%
% Outputs
% test_targets - Predicted targets
%   param_struct    - A parameter structure containing the parameters of the Gaussians found

classes             = unique(train_targets); %Number of classes in targets
Nclasses            = length(classes);
Nalpha = Ngaussians;  %Number of Gaussians in each class
Dim                 = size(train_patterns1);

max_iter    = 100;
max_try             = 5;
Pw = zeros(Nclassesmax(Ngaussians));
sigma = zeros(Nclassesmax(Ngaussians)size(train_patterns1)size(train_patterns1));
mu = zeros(Nclassesmax(Ngaussians)size(train_patterns1));

%The initial guess is based on k-means preprocessing. If it does not converge after
%max_iter iterations a random guess is used.
disp(‘Using k-means for initial guess‘)
for i = 1:Nclasses
    in   = find(train_targets==classes(i));
    [initial_mu targets labels] = k_means(train_patterns(:in)train_targets(:in)Ngaussians(i));
    for j = 1:Ngaussians(i)
        gauss_labels    = find(labels==j);
        Pw(ij)         = length(gauss_labels) / length(labels);
        sigma(ij::)  = diag(std(train_patterns(:in(gauss_labels))‘));
    end
    mu(i1:Ngaussians(i):) = initial_mu‘;
end

%Do the EM: Estimate mean and covariance for each class 
for c = 1:Nclasses
    train    = find(train_targets == classes(c));
    
    if (Ngaussians(c) == 1)
        %If there is only one Gaussian there is no need to do a whole EM procedure
        sigma(c1::)  = sqrtm(cov(train_patterns(:train)‘1));
        mu(c1:)       = mean(train_patterns(:train)‘);
    else
        
        sigma_i         = squeeze(sigma(c:::));
        old_sigma       = zeros(size(sigma_i));  %Used for the stopping criterion
        iter = 0; %Iteration counter
        n   = length(train); %Number of training points
        qi     = zeros(Nalpha(c)n);     %This will hold qi‘s
        P = zeros(1Nalpha(c));
        Ntry            = 0;
        
        while ((sum(sum(sum(abs(sigma_i-old_sigma)))) > 1e-4) & (Ntry < max_try))
            old_sigma = sigma_i;
            
            %E step: Compute Q(theta; theta_i)
            for t = 1:n
                data  = train_patterns(:train(t));
                for k = 1:Nalpha(c)
                    P(k) = Pw(ck) * p_single(data squeeze(mu(ck:)) squeeze(sigma_i(k::)));
                end          
                
                for i = 1:Nalpha(c)
                    qi(it) = P(i) / sum(P);
                end
            end
            
            %M step: theta_i+1 <- argmax(Q(the

 属性            大小     日期    时间   名称
----------- ---------  ---------- -----  ----

     文件       5336  2007-12-18 20:47  EM.m

     文件       1834  2003-06-26 21:14  k_means.m

----------- ---------  ---------- -----  ----

                 7170                    2


评论

共有 条评论