function [Am,Bm,sigmam] = gibbsnmf(X,N,M,varargin)
% GIBBSNMF Non-negative matrix factorization Gibbs sampler
%
% Usage
% [Am,Bm,sm] = gibbsnmf(X,N,M,[options])
%
% Input
% X Data matrix (I x J)
% N Number of components
% M Number of Gibbs samples to compute
% options
% alpha Prior for A (I x N)
% beta Prior for B (N x J)
% theta Prior for sigma
% k Prior for sigma
%
% A Initial value for A (I x N)
% B Initial value for B (N x J)
% sigma Initial value for noise variance (sigma^2)
%
% chains Number of chains to run (default 1)
% skip Number of initial samples to skip (default 100)
% stride Return every n'th sample (default 1)
%
% nA Do not sample from these columns of A (N x 1 logical)
% nB Do not sample from these rows of B (N x 1 logical)
% ns Do not sample from sigma (logical)
%
% Output
% Am Samples of A (I x N x M)
% Bm Samples of B (N x J x M)
% sm Samples of sigma (M x 1)
%
% Author
% Mikkel N. Schmidt,
% DTU Informatics, Technical University of Denmark.
% Copyright Mikkel N. Schmidt, ms@it.dk, www.mikkelschmidt.dk
[I,J] = size(X);
opts = mgetopt(varargin);
alpha = updim(mgetopt(opts, 'alpha', zeros(I,N)),[I N]);
beta = updim(mgetopt(opts, 'beta', zeros(N,J)),[N J]);
theta = mgetopt(opts, 'theta', 0);
k = mgetopt(opts, 'k', 0);
A0 = mgetopt(opts, 'A', rand(I,N));
B0 = mgetopt(opts, 'B', rand(N,J));
sigma0 = mgetopt(opts, 'sigma', 1);
chains = mgetopt(opts, 'chains', 1);
stride = mgetopt(opts, 'stride', 1);
skip = mgetopt(opts, 'skip', 0);
nA = mgetopt(opts, 'nA', false(N,1));
nB = mgetopt(opts, 'nB', false(N,1));
ns = mgetopt(opts, 'ns', false);
Am = zeros(I,N,M*chains);
Bm = zeros(N,J,M*chains);
sigmam = zeros(M*chains,1);
x = sum(X(:).^2)/2;
for r = 1:chains
A = A0;
B = B0;
sigma = sigma0;
for m = 1:M
for i = 1:skip*(m==1)+stride*(m>1)
C = B*B';
D = X*B';
for n = 1:N
if ~nA(n)
nn = [1:n-1 n+1:N];
A(:,n) = randr((D(:,n)-A(:,nn)*C(nn,n))/C(n,n), ...
sigma/C(n,n), alpha(:,n));
end
end
if ~ns
sigma = 1/gamrnd((I*J)/2+1+k, ...
1/(x+theta+sum(sum(A.*(A*C-2*D)))/2));
end
E = A'*A;
F = A'*X;
for n = 1:N
if ~nB(n)
nn = [1:n-1 n+1:N];
B(n,:) = randr((F(n,:)-E(n,nn)*B(nn,:))'/E(n,n), ...
sigma/E(n,n), beta(n,:)');
end
end
end
Am(:,:,m+(r-1)*M) = A;
Bm(:,:,m+(r-1)*M) = B;
sigmam(m+(r-1)*M) = sigma;
end
end
%--------------------------------------------------------------------------
function X = updim(x, dim)
% UPDIM Replicate and tile an array
%
% Usage
% X = updim(x, dim)
% Copyright 2007 Mikkel N. Schmidt, ms@it.dk, www.mikkelschmidt.dk
dimx = zeros(size(dim));
for k = 1:length(dim), dimx(k) = size(x,k); end
dim(dim==dimx) = 1;
X = repmat(x,dim);
%--------------------------------------------------------------------------
function x = randr(m, s, l)
% RANDR Random numbers from
% p(x)=K*exp(-(x-m)^2/s-l'x), x>=0
%
% Usage
% x = randr(m,s,l)
% Copyright 2007 Mikkel N. Schmidt, ms@it.dk, www.mikkelschmidt.dk
A = (l.*s-m)./(sqrt(2*s));
a = A>26;
x = zeros(size(m));
y = rand(size(m));
x(a) = -log(y(a))./((l(a).*s-m(a))./s);
R = erfc(abs(A(~a)));
x(~a) = erfcinv(y(~a).*R-(A(~a)<0).*(2*y(~a)+R-2)).*sqrt(2*s)+m(~a)-l(~a).*s;
x(isnan(x)) = 0;
x(x<0) = 0;
x(isinf(x)) = 0;
x = real(x);
%--------------------------------------------------------------------------
function out = mgetopt(varargin)
% MGETOPT Parser for optional arguments
%
% Usage
% Get alpha parameter structure from 'varargin'
% opts = mgetopt(varargin);
%
% Get and parse alpha parameter:
% var = mgetopt(opts, varname, default);
% opts: parameter structure
% varname: name of variable
% default: default value if variable is not set
%
% var = mgetopt(opts, varname, default, command, argument);
% command, argument:
% String in set:
% 'instrset', {'str1', 'str2', ... }
%
% Example
% function y = myfun(x, varargin)
% ...
% opts = mgetopt(varargin);
% parm1 = mgetopt(opts, 'parm1', 0)
% ...
% Copyright 2007 Mikkel N. Schmidt, ms@it.dk, www.mikkelschmidt.dk
if nargin==1
if isempty(varargin{1})
out = struct;
elseif isstruct(varargin{1})
out = varargin{1}{:};
elseif isstruct(varargin{1}{1})
out = varargin{1}{1};
else
out = cell2struct(varargin{1}(2:2:end),varargin{1}(1:2:end),2);
end
elseif nargin>=3
opts = varargin{1};
varname = varargin{2};
default = varargin{3};
validation = varargin(4:end);
if isfield(opts, varname)
out = opts.(varname);
else
out = default;
end
for narg = 1:2:length(validation)
cmd = validation{narg};
arg = validation{narg+1};
switch cmd
case 'instrset',
if ~any(strcmp(arg, out))
fprintf(['Wrong argument %sigma = ''%sigma'' - ', ...
'Using default : %sigma = ''%sigma''\n'], ...
varname, out, varname, default);
out = default;
end
case 'dim'
if ~all(size(out)==arg)
fprintf(['Wrong argument dimension: %sigma - ', ...
'Using default.\n'], ...
varname);
out = default;
end
otherwise,
error('Wrong option: %sigma.', cmd);
end
end
end