-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Phil Tooley
committed
May 7, 2019
1 parent
536773e
commit b7c0202
Showing
6 changed files
with
391 additions
and
255 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,124 +1,185 @@ | ||
#!/usr/bin/env python3 | ||
|
||
""" Mathematical analysis functions for image and map comparison | ||
""" | ||
|
||
from collections import namedtuple | ||
from textwrap import wrap | ||
import os | ||
|
||
import numpy as np | ||
import matplotlib.pyplot as plt | ||
import scipy.stats as sps | ||
|
||
from tabulate import tabulate | ||
|
||
import flannel.io as fio | ||
|
||
from .image_routines import (calculate_mutual_information, | ||
load_image, | ||
load_pfire_map) | ||
from .image_routines import load_image, load_pfire_map | ||
|
||
MIResult = namedtuple("mi_result", ['mi', 'hist']) | ||
|
||
def calculate_entropy(prob_dist): | ||
r""" | ||
Calculate Shannon entropy of the provided probability distribution | ||
Shannon Entropy is defined as | ||
$H(X) = \sum_n p_X(x)\log_2{p_X(x)}$ | ||
""" | ||
# First disregard all values where p_X == 0 to avoid nans from log(p_X) | ||
normed_prob_dist = prob_dist[prob_dist > 0] | ||
normed_prob_dist /= normed_prob_dist.sum() | ||
entropy = -np.sum(normed_prob_dist * np.log2(normed_prob_dist)) | ||
|
||
return entropy | ||
|
||
|
||
def calculate_mutual_information(data1, data2, resolution=50, | ||
return_hist=False): | ||
r""" | ||
Calculate mutual information using Shannon entropy of provided data. | ||
Mutual Information is defined as: | ||
MI(X, Y) = H(X) + H(Y) - H(X,Y) | ||
Where H(X), H(Y) and H(X,Y) are the Shannon entropies of the probabilities | ||
and the joint probability of the data. | ||
N.B it is assumed that the two datasets are independent. | ||
Returns a tuple of MI(X,Y), H(X), H(Y), H(X,Y) | ||
""" | ||
jointmax = max(data1.max(), data2.max()) | ||
# First calculate probability density | ||
bin_edges = np.linspace(0, 1, num=resolution) | ||
prob1_2, _, _ = np.histogram2d(data1.flatten()/jointmax, | ||
data2.flatten()/jointmax, | ||
bins=bin_edges, density=True) | ||
prob1 = np.sum(prob1_2, axis=1) | ||
prob2 = np.sum(prob1_2, axis=0) | ||
|
||
entropy1 = calculate_entropy(prob1) | ||
entropy2 = calculate_entropy(prob2) | ||
entropy1_2 = calculate_entropy(prob1_2) | ||
|
||
mutual_information = entropy1 + entropy2 - entropy1_2 | ||
|
||
if return_hist: | ||
return (mutual_information, entropy1, entropy2, entropy1_2, prob1_2) | ||
else: | ||
return (mutual_information, entropy1, entropy2, entropy1_2) | ||
|
||
|
||
def plot_2dhist(data, path, title): | ||
""" Helper function to plot 2d histogram and return rst inclusion command. | ||
""" | ||
plt.matshow(data, origin='lower', cmap='gray') | ||
plt.title("\n".join(wrap(title, 40))) | ||
plt.savefig(path) | ||
plt.close() | ||
|
||
return ".. image:: {}\n".format(path) | ||
|
||
def compare_image_results(fixed_path, moved_path, shirt_registered_path, | ||
pfire_registered_path, save_figs=None): | ||
|
||
def calculate_proficiency(alpha, beta): | ||
""" Calculate proficiency (normalized mutual information) of an image pair | ||
""" | ||
alpha_data = load_image(alpha) | ||
beta_data = load_image(beta) | ||
|
||
res = calculate_mutual_information(alpha_data, beta_data, return_hist=True) | ||
|
||
prof = res[0]/min(res[1], res[2]) | ||
|
||
return MIResult(prof, res[-1]) | ||
|
||
|
||
def compare_image_results(fixed_path, moved_path, accepted_path, | ||
pfire_path, fig_dir=None, cmpname="accepted"): | ||
"""Compare ShIRT and pFIRE registered images | ||
""" | ||
if fig_dir: | ||
os.makedirs(os.path.normpath(fig_dir), mode=0o755, exist_ok=True) | ||
|
||
mi_start = calculate_proficiency(fixed_path, moved_path) | ||
mi_accepted = calculate_proficiency(fixed_path, accepted_path) | ||
mi_pfire = calculate_proficiency(fixed_path, pfire_path) | ||
mi_comparison = calculate_proficiency(accepted_path, pfire_path) | ||
|
||
res_table = [("Normalized mutual information (proficiency):", ""), | ||
("Fixed vs. Moved:", "{:.3f}".format(mi_start.mi)), | ||
("{} vs. Fixed:".format(cmpname), "{:.3f}".format(mi_accepted.mi)), | ||
("pFIRE vs. Fixed:", "{:.3f}".format(mi_pfire.mi)), | ||
("pFIRE vs. {}:".format(cmpname), "{:.3f}\n".format(mi_comparison.mi))] | ||
|
||
print(tabulate(res_table, headers="firstrow", tablefmt='grid') + "\n") | ||
|
||
rst_output = [] | ||
rst_output.append(tabulate(res_table, headers="firstrow", tablefmt="rst")) | ||
rst_output.append("") # table must be followed by blank line | ||
|
||
fixed_data = load_image(fixed_path) | ||
moved_data = load_image(moved_path) | ||
shirt_data = load_image(shirt_registered_path) | ||
pfire_data = load_image(pfire_registered_path) | ||
|
||
mi_start = calculate_mutual_information(fixed_data, moved_data, | ||
return_hist=True) | ||
mi_shirt = calculate_mutual_information(fixed_data, shirt_data, | ||
return_hist=True) | ||
mi_pfire = calculate_mutual_information(fixed_data, pfire_data, | ||
return_hist=True) | ||
mi_comparison = calculate_mutual_information(shirt_data, pfire_data, | ||
return_hist=True) | ||
|
||
prof_start = mi_start[0]/mi_start[1] | ||
prof_shirt = mi_shirt[0]/min(mi_shirt[1], mi_shirt[2]) | ||
prof_pfire = mi_pfire[0]/min(mi_pfire[1], mi_pfire[2]) | ||
rdcy_norm = mi_comparison[0]/min(mi_comparison[1], mi_comparison[2]) | ||
|
||
|
||
print("Normalized mutual information (proficiency):") | ||
print("Pre registration: {:.3f}".format(prof_start)) | ||
print("Shirt registration: {:.3f}".format(prof_shirt)) | ||
print("pFIRE registration: {:.3f}".format(prof_pfire)) | ||
print("Similarity between results (normalized redundancy):") | ||
print("ShIRT vs. pFIRE: {:.3f}\n".format(rdcy_norm)) | ||
|
||
if save_figs: | ||
plt.matshow(mi_start[-1], origin='lower') | ||
plt.title("\n".join(wrap("Pre-registration normalized mutual information: " | ||
"{:0.3f}".format(prof_start), 40))) | ||
plt.savefig("{}_prereg.png".format(os.path.splitext(save_figs)[0])) | ||
plt.close() | ||
|
||
plt.matshow(mi_shirt[-1], origin='lower') | ||
plt.title("\n".join(wrap("ShIRT registration normalized mutual information: " | ||
"{:0.3f}".format(prof_shirt), 40))) | ||
plt.savefig("{}_shirt.png".format(os.path.splitext(save_figs)[0])) | ||
plt.close() | ||
|
||
plt.matshow(mi_pfire[-1], origin='lower') | ||
plt.title("\n".join(wrap("pFIRE registration normalized mutual information: " | ||
"{:0.3f}".format(prof_pfire), 40))) | ||
plt.savefig("{}_pfire.png".format(os.path.splitext(save_figs)[0])) | ||
plt.close() | ||
|
||
plt.matshow(mi_comparison[-1], origin='lower') | ||
plt.title("\n".join(wrap("ShIRT - pFIRE comparison (normalized redundancy): " | ||
"{:0.3f}".format(rdcy_norm), 40))) | ||
plt.savefig("{}_comparison.png".format(os.path.splitext(save_figs)[0])) | ||
plt.close() | ||
|
||
mi_data = (prof_start, prof_shirt, prof_pfire, rdcy_norm) | ||
|
||
return mi_data | ||
|
||
|
||
def compare_map_results(shirt_map_path, pfire_map_path, save_figs=None): | ||
image_rst = [] | ||
|
||
if fig_dir: | ||
image_rst.append(plot_2dhist( | ||
mi_start.hist, os.path.join(fig_dir, "prereg.png"), | ||
"Fixed vs. Moved normalized mutual information: " | ||
"{:0.3f}".format(mi_start.mi))) | ||
|
||
image_rst.append(plot_2dhist( | ||
mi_accepted.hist, os.path.join(fig_dir, "accepted.png"), | ||
"{} vs. Fixed normalized mutual information: " | ||
"{:0.3f}".format(cmpname, mi_accepted.mi))) | ||
|
||
image_rst.append(plot_2dhist( | ||
mi_pfire.hist, os.path.join(fig_dir, "pfire.png"), | ||
"pFIRE vs Fixed normalized mutual information: " | ||
"{:0.3f}".format(mi_pfire.mi))) | ||
|
||
image_rst.append(plot_2dhist( | ||
mi_comparison.hist, os.path.join(fig_dir, "comparison.png"), | ||
"pFIRE vs. {} normalized mutual information: " | ||
"{:0.3f}".format(cmpname, mi_comparison.mi))) | ||
|
||
return ("\n".join(rst_output), "\n".join(image_rst)) | ||
|
||
|
||
def compare_map_results(cmp_map_path, pfire_map_path, fig_dir=None, | ||
cmpname='Accepted'): | ||
"""Compare ShIRT and pFIRE displacement maps | ||
""" | ||
if fig_dir: | ||
os.makedirs(os.path.normpath(fig_dir), mode=0o755, exist_ok=True) | ||
|
||
shirt_map = fio.load_map(shirt_map_path)[0][0:3] | ||
cmp_map = fio.load_map(cmp_map_path)[0][0:3] | ||
pfire_map = load_pfire_map(pfire_map_path) | ||
|
||
print("Map coefficients of determination (R^2), per dimension:") | ||
table_entries = [("Map coefficients of determination (R^2), by dimension:", "")] | ||
image_entries = [] | ||
|
||
corr_data = [] | ||
|
||
plt.plot(shirt_map[0].flatten(), marker='x', ls='none', label="ShIRT") | ||
plt.plot(pfire_map[0].flatten(), marker='+', ls='none', label="pFIRE") | ||
corr_data.append(sps.linregress(shirt_map[0].flatten(), | ||
pfire_map[0].flatten())[2]) | ||
print("X: {:0.3}".format(corr_data[-1]**2)) | ||
plt.title("Map X component, R^2={:0.3}".format(corr_data[-1]**2)) | ||
plt.legend() | ||
if save_figs: | ||
plt.savefig("{}_map_x.png".format(os.path.splitext(save_figs)[0])) | ||
plt.close() | ||
for didx, dim in enumerate(['X', 'Y', 'Z']): | ||
try: | ||
corr = sps.linregress(cmp_map[didx].flatten(), | ||
pfire_map[didx].flatten())[2] | ||
table_entries.append(("{}:".format(dim), "{:0.3}".format(corr**2))) | ||
|
||
plt.plot(shirt_map[1].flatten(), marker='x', ls='none', label="ShIRT") | ||
plt.plot(pfire_map[1].flatten(), marker='+', ls='none', label="pFIRE") | ||
corr_data.append(sps.linregress(shirt_map[1].flatten(), | ||
pfire_map[1].flatten())[2]) | ||
plt.title("Map Y component, R^2={:0.3}".format(corr_data[-1]**2)) | ||
print("Y: {:0.3}".format(corr_data[-1]**2)) | ||
plt.legend() | ||
if save_figs: | ||
plt.savefig("{}_map_y.png".format(os.path.splitext(save_figs)[0])) | ||
plt.close() | ||
if fig_dir: | ||
savepath = os.path.join(fig_dir, "map_{}.png".format(dim.lower())) | ||
plt.plot(cmp_map[didx].flatten(), marker='x', ls='none', | ||
label=cmpname) | ||
plt.plot(pfire_map[didx].flatten(), marker='+', ls='none', | ||
label="pFIRE") | ||
plt.title("Map {} component, R^2={:0.3}".format(dim, corr**2)) | ||
plt.legend() | ||
plt.savefig(savepath) | ||
plt.close() | ||
image_entries.append(".. image:: {}".format(savepath)) | ||
except IndexError: | ||
break | ||
|
||
print(tabulate(table_entries, headers="firstrow", tablefmt="grid")) | ||
|
||
table = tabulate(table_entries, headers="firstrow", tablefmt="rst") | ||
|
||
try: | ||
plt.plot(shirt_map[2].flatten(), marker='x', ls='none', label="ShIRT") | ||
plt.plot(pfire_map[2].flatten(), marker='+', ls='none', label="pFIRE") | ||
corr_data.append(sps.linregress(shirt_map[2].flatten(), | ||
pfire_map[2].flatten())[0]) | ||
plt.title("Map Z component, R^2={:0.3}".format(corr_data[2]**2)) | ||
print("Z: {:0.3}".format(corr_data[2]**2)) | ||
plt.legend() | ||
if save_figs: | ||
plt.savefig("{}_map_z.png".format(os.path.splitext(save_figs)[0])) | ||
plt.close() | ||
except IndexError: | ||
pass | ||
|
||
return corr_data | ||
return (table, "\n".join(image_entries)) |
Oops, something went wrong.