Skip to content

Commit

Permalink
Merge pull request #3 from WTCN-computational-anatomy-group/tissue-we…
Browse files Browse the repository at this point in the history
…ights

Tissue weights
  • Loading branch information
brudfors authored Mar 16, 2022
2 parents 83e5d23 + 50b3d9e commit ad084d9
Show file tree
Hide file tree
Showing 10 changed files with 417 additions and 211 deletions.
6 changes: 3 additions & 3 deletions spm_mb_appearance.m
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
for n=1:numel(dat)
if isfield(dat(n).model,'gmm')
p = dat(n).model.gmm.pop;
same = all(sum(diff(sett.gmm(p).pr{1},1,2).^2,1))==0;
same = ~all(sum(diff(sett.gmm(p).pr{1},1,2).^2,1));
if same
% Intensity priors are identical for all clusters
% so need to break the symmetry.
Expand Down Expand Up @@ -276,7 +276,7 @@
W = sett.gmm(p).pr{3};
for k=1:size(W,3)
S = inv(W(:,:,k));
W(:,:,k) = inv(0.999*S + 0.001*mean(diag(S))*eye(size(S)));
W(:,:,k) = inv(S*(1-1e-9) + 1e-9*mean(diag(S))*eye(size(S)));
end
sett.gmm(p).pr{3} = W;

Expand Down Expand Up @@ -624,7 +624,7 @@ function debug_show(img,img_is,modality,fig_title,do_show)
c = 1; % Channel to show
img = img(:,:,:,c);
elseif strcmp(img_is,'template')
img = spm_mb_classes('template_k1',img,4);
img = spm_mb_classes('template_k1',img,[],4);
end
clim = [-Inf Inf];
if modality == 2
Expand Down
158 changes: 88 additions & 70 deletions spm_mb_classes.m
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,9 @@
% P - Updated tissue classes
%
% FORMAT [dat,P] = spm_mb_classes('update_cat',dat,mu,sett)
% FORMAT lab = spm_mb_classes('get_labels',dat,K1)
% FORMAT cm = spm_mb_classes('get_label_conf_matrix',cm_map,w,K1)
% FORMAT l = spm_mb_classes('LSE',mu,ax)
% FORMAT mu = spm_mb_classes('template_k1',mu,ax)
% FORMAT l = spm_mb_classes('LSE0',mu,ax)
% FORMAT l = spm_mb_classes('LSE1',mu,ax)
% FORMAT mu = spm_mb_classes('template_k1',mu,delta)
%__________________________________________________________________________
% Copyright (C) 2019-2020 Wellcome Centre for Human Neuroimaging

Expand All @@ -27,29 +25,86 @@
%==========================================================================
function [P,dat] = get_classes(dat,mu,sett)

% Memory hungry. Needs to be addressed later.
mu = add_delta(mu,dat.delta);

if isfield(dat.model,'cat')
% Categorical model
[dat,P] = update_cat(dat,mu);
elseif isfield(dat.model,'gmm')
% GMM model

% Expand mu to include the background class and combine with labels if
% required.
mu = template_k1(mu);
lab = get_labels(dat,size(mu,4));
if numel(lab)>1
% Add labels to template
% mu = mu + lab;
end
clear lab
% Expand mu to include the background class.
mu1 = template_k1(mu);
if sett.gmm(dat.model.gmm.pop).nit_appear >0
[dat,P] = spm_mb_appearance('update',dat,mu,sett);
[dat,P] = spm_mb_appearance('update',dat,mu1,sett);
else
P = exp(bsxfun(@minus,mu(:,:,:,1:(size(mu,4)-1)),LSE1(mu,4)));
P = exp(bsxfun(@minus,mu1(:,:,:,1:(size(mu1,4)-1)),LSE1(mu1,4)));
end
else
error('This should not happen');
end
if ~isempty(dat.delta)
[dat.delta,tmp] = update_delta(dat.delta,mu,P,sett.del_settings,sett.accel);
dat.E(1) = dat.E(1)+tmp;
end
%==========================================================================

%==========================================================================
function [delta,dE] = update_delta(delta,mu,P,del_settings,accel)
% disp(exp(delta)/sum(exp(delta)))
K = size(mu,4);
L = (eye(K)-1/(K+1))*del_settings;
H = L;
g = L*delta(:);
for k=1:size(mu,3)
[g1,H1] = gradhess1(mu(:,:,k,:),P(:,:,k,:),delta,accel);
g = g + double(reshape(sum(sum(g1,1),2),[K 1]));
H = H + double(reshape(sum(sum(H1,1),2),[K K]));
end
dE = 0.5*delta(:)'*L*delta(:);
delta(:) = delta(:) - H\g;
%==========================================================================

%==========================================================================
function [g,H] = gradhess1(mu,P,delta,accel)
dm = size(mu);
K = size(mu,4);
Ab = 0.5*(eye(K)-1/(K+1)); % Bohnings bound on the Hessian
if nargin>=3 && ~isempty(delta)
delta = reshape(delta,[1 1 1 K]);
mu = bsxfun(@plus,mu,delta);
end
H = zeros([dm(1:3),K,K]);
g = zeros([dm(1:3),K,1]);
sig = softmax0(mu);
msk = ~(all(isfinite(sig),4) & all(isfinite(P),4));
for k=1:K
sig_k = sig(:,:,:,k);
tmp = sig_k - P(:,:,:,k);
tmp(msk) = 0;
g(:,:,:,k) = tmp;
tmp = (sig_k - sig_k.^2)*accel + (1-accel)*Ab(k,k);
tmp(msk) = 0;
H(:,:,:,k,k) = tmp;
for k1=(k+1):K
tmp = (-sig_k.*sig(:,:,:,k1))*accel + (1-accel)*Ab(k,k1);
tmp(msk) = 0;
H(:,:,:,k,k1) = tmp;
H(:,:,:,k1,k) = tmp;
end
end
%==========================================================================

%==========================================================================
function P = softmax0(mu,ax)
% safe softmax function (matches LSE0)

if nargin<2, ax = 4; end
mx = max(mu,[],ax);
E = exp(bsxfun(@minus,mu,mx));
den = sum(E,ax)+exp(-mx);
P = bsxfun(@rdivide,E,den);
%==========================================================================

%==========================================================================
Expand All @@ -62,73 +117,40 @@
% Compute subject-specific categorical cross-entropy loss between
% segmentation and template
msk = all(isfinite(P),4) & all(isfinite(mu),4);
tmp = sum(P.*mu,4) - LSE(mu,4);
tmp = sum(P.*mu,4) - LSE0(mu,4);
dat.E(1) = -sum(tmp(msk(:)));
dat.nvox = sum(msk(:));
%==========================================================================

%==========================================================================
function lab = get_labels(dat, K1)
lab=0; return;
if isempty(dat.lab), lab = 0; return; end

% Load labels
lab = spm_mb_io('get_data', dat.lab.f);
sk = dat.samp;
lab = lab(1:sk(1):end, 1:sk(2):end, 1:sk(3):end);
dm = [size(lab) 1 1];
lab = round(lab(:));
cm_map = dat.lab.cm_map; % cell array that defines the confusion matrix
lab(~isfinite(lab) | lab<1 | lab>numel(cm_map)) = numel(cm_map) + 1; % Prevent crash

% Get confusion matrix that maps from label value to (log) probability value
cm = get_label_conf_matrix(cm_map, dat.lab.w, K1);

% Build "one-hot" representation using confusion matrix
lab = reshape(cm(lab,:), [dm(1:3) K1]);
%==========================================================================

%==========================================================================
function cm = get_label_conf_matrix(cm_map, w, K1)
% FORMAT CM = get_label_conf_matrix(cm_map, w, K1)
% cm_map - Defines the confusion matrix
% w - Weighting probability
% K1 - Number of classes
% cm - confusion matrix
%
% Build Rater confusion matrix for one subject.
% This matrix maps template classes to manually segmented classes.
% Manual labels often do not follow the same convention as the Template,
% and not all regions may be labelled. Therefore, a manual label may
% correspond to several Template classes and, conversely, one Template
% class may correspond to several manual labels.

% Parse function settings
w = min(max(w, 1e-7), 1-1e-7);
L = numel(cm_map); % Number of labels
cm = zeros([L+1, K1], 'single'); % Allocate confusion matrix (including unknown)
for l=1:L % Loop over labels
ix = false(1, K1);
ix(cm_map{l}) = true;
cm(l, ix) = log(w/sum( ix));
cm(l,~ix) = log((1-w)/sum(~ix));
function mu1 = add_delta(mu,delta)
if isempty(delta)
mu1 = mu;
else
mu1 = bsxfun(@plus,mu,reshape(delta,[1 1 1 size(mu,4)]));
end
%==========================================================================

%==========================================================================
function l = LSE(mu,ax)
% log-sum-exp function
function l = LSE0(mu,ax)
% Strictly convex log-sum-exp function
% https://en.wikipedia.org/wiki/LogSumExp#A_strictly_convex_log-sum-exp_type_function
if nargin<2, ax = 4; end
mx = max(max(mu,[],ax),0);
l = log(exp(-mx) + sum(exp(bsxfun(@minus,mu,mx)),ax)) + mx;
%==========================================================================

%==========================================================================
function mu = template_k1(mu,ax)
function mu1 = template_k1(mu,delta,ax)
% Expand a template to include the implicit background class
if nargin<2, ax = 4; end
lse = LSE(mu,ax);
mu = cat(ax,bsxfun(@minus,mu,lse), -lse);
if nargin>=2
mu1 = add_delta(mu,delta);
else
mu1 = mu;
end
if nargin<3,ax=4; end
lse = LSE0(mu1,ax);
mu1 = cat(ax,bsxfun(@minus,mu1,lse), -lse);
%==========================================================================

%==========================================================================
Expand All @@ -141,7 +163,3 @@

%==========================================================================

%==========================================================================

%==========================================================================

36 changes: 6 additions & 30 deletions spm_mb_fit.m
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
%__________________________________________________________________________
% Copyright (C) 2020 Wellcome Centre for Human Neuroimaging

% $Id: spm_mb_fit.m 8196 2021-12-16 15:18:25Z john $
% $Id: spm_mb_fit.m 8226 2022-02-24 10:44:46Z john $


% Repeatable random numbers
Expand Down Expand Up @@ -74,7 +74,7 @@

% Update affine only
%--------------------------------------------------------------------------
fprintf('Rigid (zoom=%d): %d x %d x %d\n',2^(numel(sz)-1),sett.ms.d);
fprintf('Rigid (zoom=1/%d): %d x %d x %d\n',2^(numel(sz)-1),sett.ms.d);
spm_plot_convergence('Init','Rigid Alignment','Objective','Iteration');
E = Inf;
for it0=1:nit_aff
Expand Down Expand Up @@ -116,38 +116,14 @@
countdown = 6;
end
end

% Finish affine registration of any subjects that need a few more iterations
for it0=1:3

if updt_mu
[mu,sett,dat,te,E] = iterate_mean(mu,sett,dat,te,nit_mu);
end

for n=1:numel(dat)
En = Inf;
for it1=1:nit_aff
oEn = En;
dat(n) = spm_mb_shape('update_simple_affines',dat(n),mu,sett);
En = sum(dat(n).E)/nvox(dat(n));
if abs(oEn-En) < sett.tol*0.2, break; end
end
end

E = sum(sum(cat(2,dat.E),2),1) + te; % Cost function after previous update
fprintf('%8.4f\n', E/nvox(dat));
spm_plot_convergence('Set',E/nvox(dat));
do_save(mu,sett,dat);
end

spm_plot_convergence('Clear');
nit_mu = 2;
nit_mu = 4;

% Update affine and diffeo (iteratively decreases the template resolution)
%--------------------------------------------------------------------------
spm_plot_convergence('Init','Diffeomorphic Alignment','Objective','Iteration');
for zm=numel(sz):-1:1 % loop over zoom levels
fprintf('\nzoom=%d: %d x %d x %d\n', 2^(zm-1), sett.ms.d);
fprintf('\nzoom=1/%d: %d x %d x %d\n', 2^(zm-1), sett.ms.d);

if updt_mu
dat = spm_mb_appearance('restart',dat,sett);
Expand All @@ -163,7 +139,7 @@
if ~updt_mu
mu = spm_mb_shape('shrink_template',mu0,Mmu,sett);
else
[mu,sett,dat,te,E] = iterate_mean(mu,sett,dat,te,nit_mu);
[mu,sett,dat,te] = iterate_mean(mu,sett,dat,te,nit_mu);
end

if updt_aff
Expand All @@ -176,7 +152,7 @@
end
fprintf('\n');

nit_max = nit_zm0 + (zm - 1)*3;
nit_max = nit_zm0 + (zm - 1)*2;
for it0=1:nit_max

oE = E/nvox(dat);
Expand Down
7 changes: 4 additions & 3 deletions spm_mb_gmm.m
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,9 @@
muo = mu(io,k);
mum = mu(im,k);

iAkmm = inv(Ak(im,im));
iAkmm = Ak(im,im);
iAkmm = iAkmm + eye(size(iAkmm))*(1e-6*max(diag(iAkmm))+1e-30);
iAkmm = inv(iAkmm);
SA = iAkmm*Ak(im,io);

% 1) observed
Expand Down Expand Up @@ -830,7 +832,6 @@
eta = N + eta0; % Number of observations

% Starting estimates (working with log(alpha0))
la0 = log(a0)*ones(K,1);
la0 = log(mean(Alpha,2));
alpha0 = exp(la0);
for it=1:100
Expand All @@ -848,7 +849,7 @@
la0 = la0 - H\g;
alpha0 = exp(la0);

if norm(la0-la0o).^2 <= norm(la0).^2*1e-9, break; end
if norm(la0-la0o)^2 <= norm(la0)^2*1e-9, break; end
end
alpha0 = alpha0 + eps;
lb = [(-eta0*(sum(gammaln(alpha0)) - gammaln(sum(alpha0))) - v0'*alpha0)
Expand Down
17 changes: 14 additions & 3 deletions spm_mb_init.m
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
% Copyright (C) 2018-2020 Wellcome Centre for Human Neuroimaging


% $Id: spm_mb_init.m 8086 2021-04-01 09:13:20Z john $
% $Id: spm_mb_init.m 8220 2022-02-09 12:21:22Z john $

[dat,sett] = mb_init1(cfg);

Expand Down Expand Up @@ -43,7 +43,7 @@
if ~isempty(sett.aff)
B = spm_mb_shape('affine_bases',sett.aff);
else
B = zeros([3 3 0]);
B = zeros([4 4 0]);
end
sett.B = B;
sett = rmfield(sett,'aff');
Expand Down Expand Up @@ -77,7 +77,7 @@

cl = cell(N,1);
dat = struct('dm',cl, 'Mat',cl, 'samp',[1 1 1], 'onam','', 'odir','',...
'q',cl, 'v',cl, 'psi',cl, 'model',cl, 'lab',cl, 'E',cl,'nvox',cl);
'q',cl, 'v',cl, 'delta', zeros(1,K), 'psi',cl, 'model',cl, 'lab',cl, 'E',cl,'nvox',cl);
n = 0;

% Process categorical data
Expand Down Expand Up @@ -113,6 +113,12 @@
dat(n).odir = sett.odir;
dat(n).v = fullfile(dat(n).odir,['v_' dat(n).onam '.nii']);
dat(n).psi = fullfile(dat(n).odir,['y_' dat(n).onam '.nii']);
if isfinite(sett.del_settings)
dat(n).delta = zeros(1,K);
else
dat(n).delta = [];
end


Kn = 0;
for c=1:Nc
Expand Down Expand Up @@ -212,6 +218,11 @@
dat(n).v = fullfile(dat(n).odir,['v_' dat(n).onam '.nii']);
dat(n).psi = fullfile(dat(n).odir,['y_' dat(n).onam '.nii']);

if isfinite(sett.del_settings)
dat(n).delta = zeros(1,K);
else
dat(n).delta = [];
end
cf = zeros(Nc,1);
for c=1:Nc
cf(c) = size(f(1).dat,4);
Expand Down
Loading

0 comments on commit ad084d9

Please sign in to comment.