Skip to content

Commit

Permalink
add html report generation
Browse files Browse the repository at this point in the history
  • Loading branch information
Phil Tooley committed May 7, 2019
1 parent 536773e commit b7c0202
Show file tree
Hide file tree
Showing 6 changed files with 391 additions and 255 deletions.
263 changes: 162 additions & 101 deletions benchmarking/pfire_benchmarking/analysis_routines.py
Original file line number Diff line number Diff line change
@@ -1,124 +1,185 @@
#!/usr/bin/env python3

""" Mathematical analysis functions for image and map comparison
"""

from collections import namedtuple
from textwrap import wrap
import os

import numpy as np
import matplotlib.pyplot as plt
import scipy.stats as sps

from tabulate import tabulate

import flannel.io as fio

from .image_routines import (calculate_mutual_information,
load_image,
load_pfire_map)
from .image_routines import load_image, load_pfire_map

MIResult = namedtuple("mi_result", ['mi', 'hist'])

def calculate_entropy(prob_dist):
r"""
Calculate Shannon entropy of the provided probability distribution
Shannon Entropy is defined as
$H(X) = \sum_n p_X(x)\log_2{p_X(x)}$
"""
# First disregard all values where p_X == 0 to avoid nans from log(p_X)
normed_prob_dist = prob_dist[prob_dist > 0]
normed_prob_dist /= normed_prob_dist.sum()
entropy = -np.sum(normed_prob_dist * np.log2(normed_prob_dist))

return entropy


def calculate_mutual_information(data1, data2, resolution=50,
return_hist=False):
r"""
Calculate mutual information using Shannon entropy of provided data.
Mutual Information is defined as:
MI(X, Y) = H(X) + H(Y) - H(X,Y)
Where H(X), H(Y) and H(X,Y) are the Shannon entropies of the probabilities
and the joint probability of the data.
N.B it is assumed that the two datasets are independent.
Returns a tuple of MI(X,Y), H(X), H(Y), H(X,Y)
"""
jointmax = max(data1.max(), data2.max())
# First calculate probability density
bin_edges = np.linspace(0, 1, num=resolution)
prob1_2, _, _ = np.histogram2d(data1.flatten()/jointmax,
data2.flatten()/jointmax,
bins=bin_edges, density=True)
prob1 = np.sum(prob1_2, axis=1)
prob2 = np.sum(prob1_2, axis=0)

entropy1 = calculate_entropy(prob1)
entropy2 = calculate_entropy(prob2)
entropy1_2 = calculate_entropy(prob1_2)

mutual_information = entropy1 + entropy2 - entropy1_2

if return_hist:
return (mutual_information, entropy1, entropy2, entropy1_2, prob1_2)
else:
return (mutual_information, entropy1, entropy2, entropy1_2)


def plot_2dhist(data, path, title):
""" Helper function to plot 2d histogram and return rst inclusion command.
"""
plt.matshow(data, origin='lower', cmap='gray')
plt.title("\n".join(wrap(title, 40)))
plt.savefig(path)
plt.close()

return ".. image:: {}\n".format(path)

def compare_image_results(fixed_path, moved_path, shirt_registered_path,
pfire_registered_path, save_figs=None):

def calculate_proficiency(alpha, beta):
""" Calculate proficiency (normalized mutual information) of an image pair
"""
alpha_data = load_image(alpha)
beta_data = load_image(beta)

res = calculate_mutual_information(alpha_data, beta_data, return_hist=True)

prof = res[0]/min(res[1], res[2])

return MIResult(prof, res[-1])


def compare_image_results(fixed_path, moved_path, accepted_path,
pfire_path, fig_dir=None, cmpname="accepted"):
"""Compare ShIRT and pFIRE registered images
"""
if fig_dir:
os.makedirs(os.path.normpath(fig_dir), mode=0o755, exist_ok=True)

mi_start = calculate_proficiency(fixed_path, moved_path)
mi_accepted = calculate_proficiency(fixed_path, accepted_path)
mi_pfire = calculate_proficiency(fixed_path, pfire_path)
mi_comparison = calculate_proficiency(accepted_path, pfire_path)

res_table = [("Normalized mutual information (proficiency):", ""),
("Fixed vs. Moved:", "{:.3f}".format(mi_start.mi)),
("{} vs. Fixed:".format(cmpname), "{:.3f}".format(mi_accepted.mi)),
("pFIRE vs. Fixed:", "{:.3f}".format(mi_pfire.mi)),
("pFIRE vs. {}:".format(cmpname), "{:.3f}\n".format(mi_comparison.mi))]

print(tabulate(res_table, headers="firstrow", tablefmt='grid') + "\n")

rst_output = []
rst_output.append(tabulate(res_table, headers="firstrow", tablefmt="rst"))
rst_output.append("") # table must be followed by blank line

fixed_data = load_image(fixed_path)
moved_data = load_image(moved_path)
shirt_data = load_image(shirt_registered_path)
pfire_data = load_image(pfire_registered_path)

mi_start = calculate_mutual_information(fixed_data, moved_data,
return_hist=True)
mi_shirt = calculate_mutual_information(fixed_data, shirt_data,
return_hist=True)
mi_pfire = calculate_mutual_information(fixed_data, pfire_data,
return_hist=True)
mi_comparison = calculate_mutual_information(shirt_data, pfire_data,
return_hist=True)

prof_start = mi_start[0]/mi_start[1]
prof_shirt = mi_shirt[0]/min(mi_shirt[1], mi_shirt[2])
prof_pfire = mi_pfire[0]/min(mi_pfire[1], mi_pfire[2])
rdcy_norm = mi_comparison[0]/min(mi_comparison[1], mi_comparison[2])


print("Normalized mutual information (proficiency):")
print("Pre registration: {:.3f}".format(prof_start))
print("Shirt registration: {:.3f}".format(prof_shirt))
print("pFIRE registration: {:.3f}".format(prof_pfire))
print("Similarity between results (normalized redundancy):")
print("ShIRT vs. pFIRE: {:.3f}\n".format(rdcy_norm))

if save_figs:
plt.matshow(mi_start[-1], origin='lower')
plt.title("\n".join(wrap("Pre-registration normalized mutual information: "
"{:0.3f}".format(prof_start), 40)))
plt.savefig("{}_prereg.png".format(os.path.splitext(save_figs)[0]))
plt.close()

plt.matshow(mi_shirt[-1], origin='lower')
plt.title("\n".join(wrap("ShIRT registration normalized mutual information: "
"{:0.3f}".format(prof_shirt), 40)))
plt.savefig("{}_shirt.png".format(os.path.splitext(save_figs)[0]))
plt.close()

plt.matshow(mi_pfire[-1], origin='lower')
plt.title("\n".join(wrap("pFIRE registration normalized mutual information: "
"{:0.3f}".format(prof_pfire), 40)))
plt.savefig("{}_pfire.png".format(os.path.splitext(save_figs)[0]))
plt.close()

plt.matshow(mi_comparison[-1], origin='lower')
plt.title("\n".join(wrap("ShIRT - pFIRE comparison (normalized redundancy): "
"{:0.3f}".format(rdcy_norm), 40)))
plt.savefig("{}_comparison.png".format(os.path.splitext(save_figs)[0]))
plt.close()

mi_data = (prof_start, prof_shirt, prof_pfire, rdcy_norm)

return mi_data


def compare_map_results(shirt_map_path, pfire_map_path, save_figs=None):
image_rst = []

if fig_dir:
image_rst.append(plot_2dhist(
mi_start.hist, os.path.join(fig_dir, "prereg.png"),
"Fixed vs. Moved normalized mutual information: "
"{:0.3f}".format(mi_start.mi)))

image_rst.append(plot_2dhist(
mi_accepted.hist, os.path.join(fig_dir, "accepted.png"),
"{} vs. Fixed normalized mutual information: "
"{:0.3f}".format(cmpname, mi_accepted.mi)))

image_rst.append(plot_2dhist(
mi_pfire.hist, os.path.join(fig_dir, "pfire.png"),
"pFIRE vs Fixed normalized mutual information: "
"{:0.3f}".format(mi_pfire.mi)))

image_rst.append(plot_2dhist(
mi_comparison.hist, os.path.join(fig_dir, "comparison.png"),
"pFIRE vs. {} normalized mutual information: "
"{:0.3f}".format(cmpname, mi_comparison.mi)))

return ("\n".join(rst_output), "\n".join(image_rst))


def compare_map_results(cmp_map_path, pfire_map_path, fig_dir=None,
cmpname='Accepted'):
"""Compare ShIRT and pFIRE displacement maps
"""
if fig_dir:
os.makedirs(os.path.normpath(fig_dir), mode=0o755, exist_ok=True)

shirt_map = fio.load_map(shirt_map_path)[0][0:3]
cmp_map = fio.load_map(cmp_map_path)[0][0:3]
pfire_map = load_pfire_map(pfire_map_path)

print("Map coefficients of determination (R^2), per dimension:")
table_entries = [("Map coefficients of determination (R^2), by dimension:", "")]
image_entries = []

corr_data = []

plt.plot(shirt_map[0].flatten(), marker='x', ls='none', label="ShIRT")
plt.plot(pfire_map[0].flatten(), marker='+', ls='none', label="pFIRE")
corr_data.append(sps.linregress(shirt_map[0].flatten(),
pfire_map[0].flatten())[2])
print("X: {:0.3}".format(corr_data[-1]**2))
plt.title("Map X component, R^2={:0.3}".format(corr_data[-1]**2))
plt.legend()
if save_figs:
plt.savefig("{}_map_x.png".format(os.path.splitext(save_figs)[0]))
plt.close()
for didx, dim in enumerate(['X', 'Y', 'Z']):
try:
corr = sps.linregress(cmp_map[didx].flatten(),
pfire_map[didx].flatten())[2]
table_entries.append(("{}:".format(dim), "{:0.3}".format(corr**2)))

plt.plot(shirt_map[1].flatten(), marker='x', ls='none', label="ShIRT")
plt.plot(pfire_map[1].flatten(), marker='+', ls='none', label="pFIRE")
corr_data.append(sps.linregress(shirt_map[1].flatten(),
pfire_map[1].flatten())[2])
plt.title("Map Y component, R^2={:0.3}".format(corr_data[-1]**2))
print("Y: {:0.3}".format(corr_data[-1]**2))
plt.legend()
if save_figs:
plt.savefig("{}_map_y.png".format(os.path.splitext(save_figs)[0]))
plt.close()
if fig_dir:
savepath = os.path.join(fig_dir, "map_{}.png".format(dim.lower()))
plt.plot(cmp_map[didx].flatten(), marker='x', ls='none',
label=cmpname)
plt.plot(pfire_map[didx].flatten(), marker='+', ls='none',
label="pFIRE")
plt.title("Map {} component, R^2={:0.3}".format(dim, corr**2))
plt.legend()
plt.savefig(savepath)
plt.close()
image_entries.append(".. image:: {}".format(savepath))
except IndexError:
break

print(tabulate(table_entries, headers="firstrow", tablefmt="grid"))

table = tabulate(table_entries, headers="firstrow", tablefmt="rst")

try:
plt.plot(shirt_map[2].flatten(), marker='x', ls='none', label="ShIRT")
plt.plot(pfire_map[2].flatten(), marker='+', ls='none', label="pFIRE")
corr_data.append(sps.linregress(shirt_map[2].flatten(),
pfire_map[2].flatten())[0])
plt.title("Map Z component, R^2={:0.3}".format(corr_data[2]**2))
print("Z: {:0.3}".format(corr_data[2]**2))
plt.legend()
if save_figs:
plt.savefig("{}_map_z.png".format(os.path.splitext(save_figs)[0]))
plt.close()
except IndexError:
pass

return corr_data
return (table, "\n".join(image_entries))
Loading

0 comments on commit b7c0202

Please sign in to comment.