Skip to content

Commit

Permalink
Merge pull request ME-ICA#4 from handwerkerd/DecisionTreeImprovements
Browse files Browse the repository at this point in the history
Decision tree improvements
  • Loading branch information
jbteves authored Nov 12, 2021
2 parents 6dd802c + d7fe729 commit 9a68e48
Show file tree
Hide file tree
Showing 11 changed files with 2,976 additions and 2,393 deletions.
246 changes: 246 additions & 0 deletions docs/building_decision_trees.rst

Large diffs are not rendered by default.

86 changes: 39 additions & 47 deletions tedana/metrics/collect.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,9 +116,7 @@ def generate_metrics(
n_components = mixing.shape[1]
comptable = pd.DataFrame(index=np.arange(n_components, dtype=int))
comptable["Component"] = [
io.add_decomp_prefix(
comp, prefix=label, max_value=comptable.shape[0]
)
io.add_decomp_prefix(comp, prefix=label, max_value=comptable.shape[0])
for comp in comptable.index.values
]

Expand All @@ -140,9 +138,7 @@ def generate_metrics(
data_optcom, mixing
)
if io_generator.verbose:
metric_maps["map echo betas"] = dependence.calculate_betas(
data_cat, mixing
)
metric_maps["map echo betas"] = dependence.calculate_betas(data_cat, mixing)

if "map percent signal change" in required_metrics:
LGR.info("Calculating percent signal change maps")
Expand All @@ -158,7 +154,7 @@ def generate_metrics(
if io_generator.verbose:
io_generator.save_file(
utils.unmask(metric_maps["map Z"] ** 2, mask),
label + ' component weights img',
label + " component weights img",
)

if ("map FT2" in required_metrics) or ("map FS0" in required_metrics):
Expand Down Expand Up @@ -341,37 +337,47 @@ def generate_metrics(
echo_betas = betas[:, i_echo, :]
io_generator.save_file(
utils.unmask(echo_betas, mask),
'echo weight ' + label + ' map split img',
echo=(i_echo + 1)
"echo weight " + label + " map split img",
echo=(i_echo + 1),
)

if write_T2S0:
echo_pred_T2_maps = pred_T2_maps[:, i_echo, :]
io_generator.save_file(
utils.unmask(echo_pred_T2_maps, mask),
'echo T2 ' + label + ' split img',
echo=(i_echo + 1)
"echo T2 " + label + " split img",
echo=(i_echo + 1),
)

echo_pred_S0_maps = pred_S0_maps[:, i_echo, :]
io_generator.save_file(
utils.unmask(echo_pred_S0_maps, mask),
'echo S0 ' + label + ' split img',
echo=(i_echo + 1)
"echo S0 " + label + " split img",
echo=(i_echo + 1),
)

# Reorder component table columns based on previous tedana versions
# NOTE: Some new columns will be calculated and columns may be reordered during
# component selection
preferred_order = (
"Component", "kappa", "rho", "variance explained",
"Component",
"kappa",
"rho",
"variance explained",
"normalized variance explained",
"estimated normalized variance explained",
"countsigFT2", "countsigFS0",
"dice_FT2", "dice_FS0",
"countnoise", "signal-noise_t", "signal-noise_p",
"d_table_score", "kappa ratio", "d_table_score_scrub",
"classification", "rationale",
"countsigFT2",
"countsigFS0",
"dice_FT2",
"dice_FS0",
"countnoise",
"signal-noise_t",
"signal-noise_p",
"d_table_score",
"kappa ratio",
"d_table_score_scrub",
"classification",
"rationale",
)
first_columns = [col for col in preferred_order if col in comptable.columns]
other_columns = [col for col in comptable.columns if col not in preferred_order]
Expand Down Expand Up @@ -468,9 +474,7 @@ def get_metadata(comptable):
}
if "dice_FS0" in comptable:
metric_metadata["dice_FS0"] = {
"LongName": (
"S0 model beta map-F-statistic map Dice similarity index"
),
"LongName": ("S0 model beta map-F-statistic map Dice similarity index"),
"Description": (
"Dice value of cluster-extent thresholded maps of "
"S0-model betas and F-statistics."
Expand Down Expand Up @@ -519,13 +523,10 @@ def get_metadata(comptable):
if "original_classification" in comptable:
metric_metadata["original_classification"] = {
"LongName": "Original classification",
"Description": (
"Classification from the original decision tree."
),
"Description": ("Classification from the original decision tree."),
"Levels": {
"accepted": (
"A BOLD-like component included in denoised and "
"high-Kappa data."
"A BOLD-like component included in denoised and " "high-Kappa data."
),
"rejected": (
"A non-BOLD component excluded from denoised and "
Expand All @@ -537,43 +538,34 @@ def get_metadata(comptable):
),
},
}
if "original_rationale" in comptable:
metric_metadata["original_rationale"] = {
"LongName": "Original rationale",
"Description": (
"The reason for the original classification. "
"Please see tedana's documentation for information about "
"possible rationales."
),
}

if "classification" in comptable:
metric_metadata["classification"] = {
"LongName": "Component classification",
"Description": (
"Classification from the manual classification procedure."
),
"Description": ("Classification from the manual classification procedure."),
"Levels": {
"accepted": (
"A BOLD-like component included in denoised and "
"high-Kappa data."
"A BOLD-like component included in denoised and " "high-Kappa data."
),
"rejected": (
"A non-BOLD component excluded from denoised and "
"high-Kappa data."
),
"ignored": (
"A low-variance component included in denoised, "
"but excluded from high-Kappa data."
),
},
}
if "classification_tags" in comptable:
metric_metadata["classification_tags"] = {
"LongName": "Component classification tags",
"Description": (
"A single tag or a comma separated list of tags to describe why a component received its classification"
),
}
if "rationale" in comptable:
metric_metadata["rationale"] = {
"LongName": "Rationale for component classification",
"Description": (
"The reason for the original classification. "
"Please see tedana's documentation for information about "
"possible rationales."
"This column label was replaced with classification_tags in late 2021"
),
}
if "kappa ratio" in comptable:
Expand Down
98 changes: 57 additions & 41 deletions tedana/reporting/static_figures.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,18 @@

import numpy as np
import matplotlib
matplotlib.use('AGG')

matplotlib.use("AGG")
import matplotlib.pyplot as plt
from nilearn import plotting

from tedana import io, stats, utils

LGR = logging.getLogger("GENERAL")
MPL_LGR = logging.getLogger('matplotlib')
MPL_LGR = logging.getLogger("matplotlib")
MPL_LGR.setLevel(logging.WARNING)
RepLGR = logging.getLogger('REPORT')
RefLGR = logging.getLogger('REFERENCES')
RepLGR = logging.getLogger("REPORT")
RefLGR = logging.getLogger("REFERENCES")


def _trim_edge_zeros(arr):
Expand All @@ -38,12 +39,14 @@ def _trim_edge_zeros(arr):

mask = arr != 0
bounding_box = tuple(
slice(np.min(indexes), np.max(indexes) + 1)
for indexes in np.where(mask))
slice(np.min(indexes), np.max(indexes) + 1) for indexes in np.where(mask)
)
return arr[bounding_box]


def carpet_plot(optcom_ts, denoised_ts, hikts, lowkts, mask, io_generator, gscontrol=None):
def carpet_plot(
optcom_ts, denoised_ts, hikts, lowkts, mask, io_generator, gscontrol=None
):
"""Generate a set of carpet plots for the combined and denoised data.
Parameters
Expand Down Expand Up @@ -122,7 +125,9 @@ def carpet_plot(optcom_ts, denoised_ts, hikts, lowkts, mask, io_generator, gscon
title="Optimally Combined Data (Pre-GSR)",
)
fig.tight_layout()
fig.savefig(os.path.join(io_generator.out_dir, "figures", "carpet_optcom_nogsr.svg"))
fig.savefig(
os.path.join(io_generator.out_dir, "figures", "carpet_optcom_nogsr.svg")
)

if (gscontrol is not None) and ("mir" in gscontrol):
mir_denoised_img = io_generator.get_name("mir denoised img")
Expand All @@ -135,7 +140,9 @@ def carpet_plot(optcom_ts, denoised_ts, hikts, lowkts, mask, io_generator, gscon
title="Denoised Data (Post-MIR)",
)
fig.tight_layout()
fig.savefig(os.path.join(io_generator.out_dir, "figures", "carpet_denoised_mir.svg"))
fig.savefig(
os.path.join(io_generator.out_dir, "figures", "carpet_denoised_mir.svg")
)

mir_denoised_img = io_generator.get_name("ICA accepted mir denoised img")
fig, ax = plt.subplots(figsize=(14, 7))
Expand All @@ -147,7 +154,9 @@ def carpet_plot(optcom_ts, denoised_ts, hikts, lowkts, mask, io_generator, gscon
title="High-Kappa Data (Post-MIR)",
)
fig.tight_layout()
fig.savefig(os.path.join(io_generator.out_dir, "figures", "carpet_accepted_mir.svg"))
fig.savefig(
os.path.join(io_generator.out_dir, "figures", "carpet_accepted_mir.svg")
)


def comp_figures(ts, mask, comptable, mmix, io_generator, png_cmap):
Expand Down Expand Up @@ -191,32 +200,36 @@ def comp_figures(ts, mask, comptable, mmix, io_generator, png_cmap):

# Create indices for 6 cuts, based on dimensions
cuts = [ts_B.shape[dim] // 6 for dim in range(3)]
expl_text = ''
expl_text = ""

# Remove trailing ';' from rationale column
comptable['rationale'] = comptable['rationale'].str.rstrip(';')
# comptable['rationale'] = comptable['rationale'].str.rstrip(';')
for compnum in comptable.index.values:
if comptable.loc[compnum, "classification"] == 'accepted':
line_color = 'g'
expl_text = 'accepted'
elif comptable.loc[compnum, "classification"] == 'rejected':
line_color = 'r'
expl_text = 'rejection reason(s): ' + comptable.loc[compnum, "rationale"]
elif comptable.loc[compnum, "classification"] == 'ignored':
line_color = 'k'
expl_text = 'ignored reason(s): ' + comptable.loc[compnum, "rationale"]
if comptable.loc[compnum, "classification"] == "accepted":
line_color = "g"
expl_text = (
"accepted reason(s): " + comptable.loc[compnum, "classification_tags"]
)
elif comptable.loc[compnum, "classification"] == "rejected":
line_color = "r"
expl_text = (
"rejection reason(s): " + comptable.loc[compnum, "classification_tags"]
)
elif comptable.loc[compnum, "classification"] == "ignored":
line_color = "k"
expl_text = (
"ignored reason(s): " + comptable.loc[compnum, "classification_tags"]
)
else:
# Classification not added
# If new, this will keep code running
line_color = '0.75'
expl_text = 'other classification'
line_color = "0.75"
expl_text = "other classification"

allplot = plt.figure(figsize=(10, 9))
ax_ts = plt.subplot2grid((5, 6), (0, 0),
rowspan=1, colspan=6,
fig=allplot)
ax_ts = plt.subplot2grid((5, 6), (0, 0), rowspan=1, colspan=6, fig=allplot)

ax_ts.set_xlabel('TRs')
ax_ts.set_xlabel("TRs")
ax_ts.set_xlim(0, n_vols)
plt.yticks([])
# Make a second axis with units of time (s)
Expand All @@ -235,17 +248,17 @@ def comp_figures(ts, mask, comptable, mmix, io_generator, png_cmap):
ax_ts2.set_xticks(ax1Xs)
ax_ts2.set_xlim(ax_ts.get_xbound())
ax_ts2.set_xticklabels(ax2Xs)
ax_ts2.set_xlabel('seconds')
ax_ts2.set_xlabel("seconds")

ax_ts.plot(mmix[:, compnum], color=line_color)

# Title will include variance from comptable
comp_var = "{0:.2f}".format(comptable.loc[compnum, "variance explained"])
comp_kappa = "{0:.2f}".format(comptable.loc[compnum, "kappa"])
comp_rho = "{0:.2f}".format(comptable.loc[compnum, "rho"])
plt_title = ('Comp. {}: variance: {}%, kappa: {}, rho: {}, '
'{}'.format(compnum, comp_var, comp_kappa, comp_rho,
expl_text))
plt_title = "Comp. {}: variance: {}%, kappa: {}, rho: {}, " "{}".format(
compnum, comp_var, comp_kappa, comp_rho, expl_text
)
title = ax_ts.set_title(plt_title)
title.set_y(1.5)

Expand All @@ -255,8 +268,10 @@ def comp_figures(ts, mask, comptable, mmix, io_generator, png_cmap):

for idx, _ in enumerate(cuts):
for imgslice in range(1, 6):
ax = plt.subplot2grid((5, 6), (idx + 1, imgslice - 1), rowspan=1, colspan=1)
ax.axis('off')
ax = plt.subplot2grid(
(5, 6), (idx + 1, imgslice - 1), rowspan=1, colspan=1
)
ax.axis("off")

if idx == 0:
to_plot = np.rot90(ts_B[imgslice * cuts[idx], :, :, compnum])
Expand All @@ -265,14 +280,15 @@ def comp_figures(ts, mask, comptable, mmix, io_generator, png_cmap):
if idx == 2:
to_plot = ts_B[:, :, imgslice * cuts[idx], compnum]

ax_im = ax.imshow(to_plot, vmin=imgmin, vmax=imgmax, aspect='equal',
cmap=png_cmap)
ax_im = ax.imshow(
to_plot, vmin=imgmin, vmax=imgmax, aspect="equal", cmap=png_cmap
)

# Add a color bar to the plot.
ax_cbar = allplot.add_axes([0.8, 0.3, 0.03, 0.37])
cbar = allplot.colorbar(ax_im, ax_cbar)
cbar.set_label('Component Beta', rotation=90)
cbar.ax.yaxis.set_label_position('left')
cbar.set_label("Component Beta", rotation=90)
cbar.ax.yaxis.set_label_position("left")

# Get fft and freqs for this subject
# adapted from @dangom
Expand All @@ -281,14 +297,14 @@ def comp_figures(ts, mask, comptable, mmix, io_generator, png_cmap):
# Plot it
ax_fft = plt.subplot2grid((5, 6), (4, 0), rowspan=1, colspan=6)
ax_fft.plot(freqs, spectrum)
ax_fft.set_title('One Sided fft')
ax_fft.set_xlabel('Hz')
ax_fft.set_title("One Sided fft")
ax_fft.set_xlabel("Hz")
ax_fft.set_xlim(freqs[0], freqs[-1])
plt.yticks([])

# Fix spacing so TR label does overlap with other plots
allplot.subplots_adjust(hspace=0.4)
plot_name = 'comp_{}.png'.format(str(compnum).zfill(3))
compplot_name = os.path.join(io_generator.out_dir, 'figures', plot_name)
plot_name = "comp_{}.png".format(str(compnum).zfill(3))
compplot_name = os.path.join(io_generator.out_dir, "figures", plot_name)
plt.savefig(compplot_name)
plt.close()
Loading

0 comments on commit 9a68e48

Please sign in to comment.