-
Notifications
You must be signed in to change notification settings - Fork 151
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[ENH,REF] Refactor
visualisation
and add temporal importance curves…
… function (#1050) * typo * tic and vis refactor * tic and vis refactor * notebook
- Loading branch information
1 parent
fee139f
commit 5c0d1d2
Showing
23 changed files
with
512 additions
and
452 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
"""Plotting tools for estimators.""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
54 changes: 54 additions & 0 deletions
54
aeon/visualisation/estimator/_temporal_importance_curves.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
"""Temporal importance curve diagram generators for interval forests.""" | ||
|
||
__author__ = ["MatthewMiddlehurst"] | ||
|
||
__all__ = ["plot_temporal_importance_curves"] | ||
|
||
import numpy as np | ||
|
||
from aeon.utils.validation._dependencies import _check_soft_dependencies | ||
|
||
|
||
def plot_temporal_importance_curves( | ||
curves, curve_names, top_curves_shown=None, plot_mean=True | ||
): | ||
"""Temporal importance curve diagram generator for interval forests.""" | ||
# find attributes to display by max information gain for any time point. | ||
_check_soft_dependencies("matplotlib") | ||
|
||
import matplotlib.pyplot as plt | ||
|
||
top_curves_shown = len(curves) if top_curves_shown is None else top_curves_shown | ||
max_ig = [max(i) for i in curves] | ||
top = sorted(range(len(max_ig)), key=lambda i: max_ig[i], reverse=True)[ | ||
:top_curves_shown | ||
] | ||
|
||
top_curves = [curves[i] for i in top] | ||
top_names = [curve_names[i] for i in top] | ||
|
||
# plot curves with highest max and the mean information gain for each time point if | ||
# enabled. | ||
for i in range(0, top_curves_shown): | ||
plt.plot( | ||
top_curves[i], | ||
label=top_names[i], | ||
) | ||
if plot_mean: | ||
plt.plot( | ||
list(np.mean(curves, axis=0)), | ||
"--", | ||
linewidth=3, | ||
label="Mean Information Gain", | ||
) | ||
plt.legend( | ||
bbox_to_anchor=(0.0, 1.02, 1.0, 0.102), | ||
loc="lower left", | ||
ncol=2, | ||
mode="expand", | ||
borderaxespad=0.0, | ||
) | ||
plt.xlabel("Time Point") | ||
plt.ylabel("Information Gain") | ||
|
||
plt.show() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
"""Plotting tools for estimator results.""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,139 @@ | ||
"""Functions for plotting results boxplot diagrams.""" | ||
|
||
__author__ = ["dguijo"] | ||
|
||
__all__ = [ | ||
"plot_boxplot_median", | ||
] | ||
|
||
import numpy as np | ||
|
||
from aeon.utils.validation._dependencies import _check_soft_dependencies | ||
|
||
|
||
def plot_boxplot_median( | ||
results, | ||
labels, | ||
plot_type="violin", | ||
outliers=True, | ||
title=None, | ||
y_min=None, | ||
y_max=None, | ||
): | ||
""" | ||
Plot a box plot of distributions from the median. | ||
Each row of results is an independent experiment for each element in names. This | ||
function works out the deviation from the median for each row, then plots a | ||
boxplot variant of each column. | ||
Parameters | ||
---------- | ||
results: np.array | ||
Scores (either accuracies or errors) of dataset x strategy | ||
labels: list of estimators | ||
List with names of the estimators | ||
plot_type: str, default = "violin" | ||
This function can create four sort of distribution plots: "violin", "swarm", | ||
"boxplot" or "strip". "violin" plot features a kernel density estimation of the | ||
underlying distribution. "swarm" draws a categorical scatterplot with points | ||
adjusted to be non-overlapping. "strip" draws a categorical scatterplot using | ||
jitter to reduce overplotting. | ||
outliers: bool, default = True | ||
Only applies when plot_type is "boxplot". | ||
title: str, default = None | ||
Title to be shown in the top of the plot. | ||
y_min: float, default = None | ||
Min value for the y_axis of the plot. | ||
y_max: float, default = None | ||
Max value for the y_axis of the plot. | ||
Returns | ||
------- | ||
fig: matplotlib.figure | ||
Figure created. | ||
Example | ||
------- | ||
>>> from aeon.visualisation import plot_boxplot_median | ||
>>> from aeon.benchmarking.results_loaders import get_estimator_results_as_array | ||
>>> methods = ["IT", "WEASEL-Dilation", "HIVECOTE2", "FreshPRINCE"] | ||
>>> results = get_estimator_results_as_array(estimators=methods) # doctest: +SKIP | ||
>>> plot = plot_boxplot_median(results[0], methods) # doctest: +SKIP | ||
>>> plot.show() # doctest: +SKIP | ||
>>> plot.savefig("boxplot.pdf") # doctest: +SKIP | ||
""" | ||
_check_soft_dependencies("matplotlib", "seaborn") | ||
import matplotlib.pyplot as plt | ||
import seaborn as sns | ||
|
||
# Obtains deviation from median for each independent experiment. | ||
medians = np.median(results, axis=1) | ||
sum_results_medians = results + medians[:, np.newaxis] | ||
|
||
deviation_from_median = np.divide( | ||
results, | ||
sum_results_medians, | ||
out=np.zeros_like(results), | ||
where=sum_results_medians != 0, | ||
) | ||
|
||
fig = plt.figure(figsize=(10, 6), layout="tight") | ||
|
||
# Plots violin or boxplots | ||
if plot_type == "violin": | ||
plot = sns.violinplot( | ||
data=deviation_from_median, | ||
linewidth=0.2, | ||
palette="pastel", | ||
bw=0.3, | ||
) | ||
elif plot_type == "boxplot": | ||
plot = sns.boxplot( | ||
data=deviation_from_median, | ||
palette="pastel", | ||
showfliers=outliers, | ||
) | ||
elif plot_type == "swarm": | ||
plot = sns.swarmplot( | ||
data=deviation_from_median, | ||
linewidth=0.2, | ||
palette="pastel", | ||
) | ||
elif plot_type == "strip": | ||
plot = sns.stripplot( | ||
data=deviation_from_median, | ||
linewidth=0.2, | ||
palette="pastel", | ||
) | ||
else: | ||
raise ValueError( | ||
"plot_type must be one of 'violin', 'boxplot', 'swarm' or 'strip'." | ||
) | ||
|
||
# Modifying limits for y-axis. | ||
if y_min is None and ( | ||
(plot_type == "boxplot" and outliers) or (plot_type != "boxplot") | ||
): | ||
y_min = np.around(np.amin(deviation_from_median) - 0.05, 2) | ||
|
||
if y_max is None and ( | ||
(plot_type == "boxplot" and outliers) or (plot_type != "boxplot") | ||
): | ||
y_max = np.around(np.amax(deviation_from_median) + 0.05, 2) | ||
|
||
plot.set_ylim(y_min, y_max) | ||
|
||
# Setting labels for x-axis. Rotate only if labels are too long. | ||
plot.set_xticks(np.arange(len(labels))) | ||
label_lengths = np.array([len(i) for i in labels]) | ||
if (sum(label_lengths) > 40) or (max(label_lengths[:-1] + label_lengths[1:]) > 20): | ||
plot.set_xticklabels(labels, rotation=45, ha="right") | ||
else: | ||
plot.set_xticklabels(labels) | ||
|
||
# Setting title if provided. | ||
if title is not None: | ||
plot.set_title(rf"{title}") | ||
|
||
return fig |
Oops, something went wrong.