Skip to content

Commit

Permalink
Merge pull request #41 from INSIGNEO/develop
Browse files Browse the repository at this point in the history
Develop
  • Loading branch information
ptooley authored May 8, 2019
2 parents 6dc55c9 + 45201e4 commit f4ac8c9
Show file tree
Hide file tree
Showing 9 changed files with 696 additions and 270 deletions.
37 changes: 37 additions & 0 deletions benchmarking/pfire_benchmarking/__main__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
#!/usr/bin/env python3

""" Testsuite entry point
"""

import argparse
import os

from .testdespatcher import TestDespatcher

def _parse_args():
parser = argparse.ArgumentParser(description="Run pFIRE integration tests")

parser.add_argument('dir', nargs='?', default=os.getcwd(), metavar="datadir",
help="Test data directory, will be inspected recursively")
parser.add_argument('--output', '-o', default='.', metavar="outdir",
help="Path at which to output results")

return parser.parse_args()


def main():
""" Run the testsuite over a directory tree
"""
args = _parse_args()

testsuite = TestDespatcher(output_dir=args.output)

testsuite.find_tests(args.dir)

testsuite.run_tests()

testsuite.create_aggregate_report()


if __name__ == "__main__":
main()
269 changes: 165 additions & 104 deletions benchmarking/pfire_benchmarking/analysis_routines.py
Original file line number Diff line number Diff line change
@@ -1,124 +1,185 @@
#!/usr/bin/env python3

""" Mathematical analysis functions for image and map comparison
"""

from collections import namedtuple
from textwrap import wrap
import os

import numpy as np
import matplotlib.pyplot as plt
import scipy.stats as sps

import flannel.io as fio
from tabulate import tabulate

from .image_routines import (calculate_mutual_information,
load_image,
load_pfire_map)
from .image_routines import load_image, load_map

def compare_image_results(fixed_path, moved_path, shirt_registered_path,
pfire_registered_path, save_figs=None):
"""Compare ShIRT and pFIRE registered images
MIResult = namedtuple("mi_result", ['mi', 'hist'])

def calculate_entropy(prob_dist):
r"""
Calculate Shannon entropy of the provided probability distribution
Shannon Entropy is defined as
$H(X) = \sum_n p_X(x)\log_2{p_X(x)}$
"""
# First disregard all values where p_X == 0 to avoid nans from log(p_X)
normed_prob_dist = prob_dist[prob_dist > 0]
normed_prob_dist /= normed_prob_dist.sum()
entropy = -np.sum(normed_prob_dist * np.log2(normed_prob_dist))

fixed_data = load_image(fixed_path)
moved_data = load_image(moved_path)
shirt_data = load_image(shirt_registered_path)
pfire_data = load_image(pfire_registered_path)

mi_start = calculate_mutual_information(fixed_data, moved_data,
return_hist=True)
mi_shirt = calculate_mutual_information(fixed_data, shirt_data,
return_hist=True)
mi_pfire = calculate_mutual_information(fixed_data, pfire_data,
return_hist=True)
mi_comparison = calculate_mutual_information(shirt_data, pfire_data,
return_hist=True)

prof_start = mi_start[0]/mi_start[1]
prof_shirt = mi_shirt[0]/min(mi_shirt[1], mi_shirt[2])
prof_pfire = mi_pfire[0]/min(mi_pfire[1], mi_pfire[2])
rdcy_norm = mi_comparison[0]/min(mi_comparison[1], mi_comparison[2])


print("Normalized mutual information (proficiency):")
print("Pre registration: {:.3f}".format(prof_start))
print("Shirt registration: {:.3f}".format(prof_shirt))
print("pFIRE registration: {:.3f}".format(prof_pfire))
print("Similarity between results (normalized redundancy):")
print("ShIRT vs. pFIRE: {:.3f}\n".format(rdcy_norm))

if save_figs:
plt.matshow(mi_start[-1], origin='lower')
plt.title("\n".join(wrap("Pre-registration normalized mutual information: "
"{:0.3f}".format(prof_start), 40)))
plt.savefig("{}_prereg.png".format(os.path.splitext(save_figs)[0]))
plt.close()

plt.matshow(mi_shirt[-1], origin='lower')
plt.title("\n".join(wrap("ShIRT registration normalized mutual information: "
"{:0.3f}".format(prof_shirt), 40)))
plt.savefig("{}_shirt.png".format(os.path.splitext(save_figs)[0]))
plt.close()

plt.matshow(mi_pfire[-1], origin='lower')
plt.title("\n".join(wrap("pFIRE registration normalized mutual information: "
"{:0.3f}".format(prof_pfire), 40)))
plt.savefig("{}_pfire.png".format(os.path.splitext(save_figs)[0]))
plt.close()

plt.matshow(mi_comparison[-1], origin='lower')
plt.title("\n".join(wrap("ShIRT - pFIRE comparison (normalized redundancy): "
"{:0.3f}".format(rdcy_norm), 40)))
plt.savefig("{}_comparison.png".format(os.path.splitext(save_figs)[0]))
plt.close()

mi_data = (prof_start, prof_shirt, prof_pfire, rdcy_norm)

return mi_data


def compare_map_results(shirt_map_path, pfire_map_path, save_figs=None):
"""Compare ShIRT and pFIRE displacement maps
return entropy


def calculate_mutual_information(data1, data2, resolution=50,
return_hist=False):
r"""
Calculate mutual information using Shannon entropy of provided data.
Mutual Information is defined as:
MI(X, Y) = H(X) + H(Y) - H(X,Y)
Where H(X), H(Y) and H(X,Y) are the Shannon entropies of the probabilities
and the joint probability of the data.
N.B it is assumed that the two datasets are independent.
Returns a tuple of MI(X,Y), H(X), H(Y), H(X,Y)
"""
jointmax = max(data1.max(), data2.max())
# First calculate probability density
bin_edges = np.linspace(0, 1, num=resolution)
prob1_2, _, _ = np.histogram2d(data1.flatten()/jointmax,
data2.flatten()/jointmax,
bins=bin_edges, density=True)
prob1 = np.sum(prob1_2, axis=1)
prob2 = np.sum(prob1_2, axis=0)

shirt_map = fio.load_map(shirt_map_path)[0][0:3]
pfire_map = load_pfire_map(pfire_map_path)
entropy1 = calculate_entropy(prob1)
entropy2 = calculate_entropy(prob2)
entropy1_2 = calculate_entropy(prob1_2)

print("Map coefficients of determination (R^2), per dimension:")
mutual_information = entropy1 + entropy2 - entropy1_2

corr_data = []
if return_hist:
return (mutual_information, entropy1, entropy2, entropy1_2, prob1_2)
else:
return (mutual_information, entropy1, entropy2, entropy1_2)

plt.plot(shirt_map[0].flatten(), marker='x', ls='none', label="ShIRT")
plt.plot(pfire_map[0].flatten(), marker='+', ls='none', label="pFIRE")
corr_data.append(sps.linregress(shirt_map[0].flatten(),
pfire_map[0].flatten())[2])
print("X: {:0.3}".format(corr_data[-1]**2))
plt.title("Map X component, R^2={:0.3}".format(corr_data[-1]**2))
plt.legend()
if save_figs:
plt.savefig("{}_map_x.png".format(os.path.splitext(save_figs)[0]))
plt.close()

plt.plot(shirt_map[1].flatten(), marker='x', ls='none', label="ShIRT")
plt.plot(pfire_map[1].flatten(), marker='+', ls='none', label="pFIRE")
corr_data.append(sps.linregress(shirt_map[1].flatten(),
pfire_map[1].flatten())[2])
plt.title("Map Y component, R^2={:0.3}".format(corr_data[-1]**2))
print("Y: {:0.3}".format(corr_data[-1]**2))
plt.legend()
if save_figs:
plt.savefig("{}_map_y.png".format(os.path.splitext(save_figs)[0]))
def plot_2dhist(data, path, title):
""" Helper function to plot 2d histogram and return rst inclusion command.
"""
plt.matshow(data, origin='lower', cmap='gray')
plt.title("\n".join(wrap(title, 40)))
plt.savefig(path)
plt.close()

try:
plt.plot(shirt_map[2].flatten(), marker='x', ls='none', label="ShIRT")
plt.plot(pfire_map[2].flatten(), marker='+', ls='none', label="pFIRE")
corr_data.append(sps.linregress(shirt_map[2].flatten(),
pfire_map[2].flatten())[0])
plt.title("Map Z component, R^2={:0.3}".format(corr_data[2]**2))
print("Z: {:0.3}".format(corr_data[2]**2))
plt.legend()
if save_figs:
plt.savefig("{}_map_z.png".format(os.path.splitext(save_figs)[0]))
plt.close()
except IndexError:
pass

return corr_data
return ".. image:: {}\n".format(os.path.basename(path))


def calculate_proficiency(alpha, beta):
""" Calculate proficiency (normalized mutual information) of an image pair
"""
alpha_data = load_image(alpha)
beta_data = load_image(beta)

res = calculate_mutual_information(alpha_data, beta_data, return_hist=True)

prof = res[0]/min(res[1], res[2])

return MIResult(prof, res[-1])


def compare_image_results(fixed_path, moved_path, accepted_path,
pfire_path, fig_dir=None, cmpname="accepted"):
"""Compare ShIRT and pFIRE registered images
"""
if fig_dir:
os.makedirs(os.path.normpath(fig_dir), mode=0o755, exist_ok=True)
else:
fig_dir = os.path.normpath('.')

mi_start = calculate_proficiency(fixed_path, moved_path)
mi_accepted = calculate_proficiency(fixed_path, accepted_path)
mi_pfire = calculate_proficiency(fixed_path, pfire_path)
mi_comparison = calculate_proficiency(accepted_path, pfire_path)

res_table = [("Normalized mutual information (proficiency):", ""),
("Fixed vs. Moved:", "{:.3f}".format(mi_start.mi)),
("{} vs. Fixed:".format(cmpname), "{:.3f}".format(mi_accepted.mi)),
("pFIRE vs. Fixed:", "{:.3f}".format(mi_pfire.mi)),
("pFIRE vs. {}:".format(cmpname), "{:.3f}\n".format(mi_comparison.mi))]

print(tabulate(res_table, headers="firstrow", tablefmt='grid') + "\n")

rst_output = []
rst_output.append(tabulate(res_table, headers="firstrow", tablefmt="rst"))
rst_output.append("") # table must be followed by blank line

image_rst = []

if fig_dir:
image_rst.append(plot_2dhist(
mi_start.hist, os.path.join(fig_dir, "prereg.png"),
"Fixed vs. Moved normalized mutual information: "
"{:0.3f}".format(mi_start.mi)))

image_rst.append(plot_2dhist(
mi_accepted.hist, os.path.join(fig_dir, "accepted.png"),
"{} vs. Fixed normalized mutual information: "
"{:0.3f}".format(cmpname, mi_accepted.mi)))

image_rst.append(plot_2dhist(
mi_pfire.hist, os.path.join(fig_dir, "pfire.png"),
"pFIRE vs Fixed normalized mutual information: "
"{:0.3f}".format(mi_pfire.mi)))

image_rst.append(plot_2dhist(
mi_comparison.hist, os.path.join(fig_dir, "comparison.png"),
"pFIRE vs. {} normalized mutual information: "
"{:0.3f}".format(cmpname, mi_comparison.mi)))

return ("\n".join(rst_output), "\n".join(image_rst))


def compare_map_results(cmp_map_path, pfire_map_path, fig_dir=None,
cmpname='Accepted'):
"""Compare ShIRT and pFIRE displacement maps
"""
if fig_dir:
os.makedirs(os.path.normpath(fig_dir), mode=0o755, exist_ok=True)

cmp_map = load_map(cmp_map_path)
pfire_map = load_map(pfire_map_path)

table_entries = [("Map coefficients of determination (R^2), by dimension:", "")]
image_entries = []

for didx, dim in enumerate(['X', 'Y', 'Z']):
try:
corr = sps.linregress(cmp_map[didx].flatten(),
pfire_map[didx].flatten())[2]
table_entries.append(("{}:".format(dim), "{:0.3}".format(corr**2)))

if fig_dir:
savepath = os.path.join(fig_dir, "map_{}.png".format(dim.lower()))
plt.plot(cmp_map[didx].flatten(), marker='x', ls='none',
label=cmpname)
plt.plot(pfire_map[didx].flatten(), marker='+', ls='none',
label="pFIRE")
plt.title("Map {} component, R^2={:0.3}".format(dim, corr**2))
plt.legend()
plt.savefig(savepath)
plt.close()
image_entries.append(".. image:: {}"
"".format(os.path.basename(savepath)))
except IndexError:
break

print(tabulate(table_entries, headers="firstrow", tablefmt="grid"))

table = tabulate(table_entries, headers="firstrow", tablefmt="rst")

return (table, "\n".join(image_entries))
Loading

0 comments on commit f4ac8c9

Please sign in to comment.