Skip to content

Commit

Permalink
[ENH,REF] Refactor visualisation and add temporal importance curves…
Browse files Browse the repository at this point in the history
… function (#1050)

* typo

* tic and vis refactor

* tic and vis refactor

* notebook
  • Loading branch information
MatthewMiddlehurst authored Jan 17, 2024
1 parent fee139f commit 5c0d1d2
Show file tree
Hide file tree
Showing 23 changed files with 512 additions and 452 deletions.
2 changes: 1 addition & 1 deletion .github/ISSUE_TEMPLATE/bug_report.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ body:
```python
from aeon.datasets import load_airline
from `aeon.visualisation import plot_series
from aeon.visualisation import plot_series
y = load_airline()
y = y.to_frame()
Expand Down
5 changes: 0 additions & 5 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,6 @@ share/python-wheels/
.installed.cfg
*.egg

#TSC results
results/

MANIFEST

#downloaded datasets
aeon/datasets/local_data/

Expand Down
4 changes: 2 additions & 2 deletions aeon/transformations/bootstrap/_mbb.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ class STLBootstrapTransformer(BaseTransformer):
--------
>>> from aeon.transformations.bootstrap import STLBootstrapTransformer
>>> from aeon.datasets import load_airline
>>> from `aeon.visualisation import plot_series # doctest: +SKIP
>>> from aeon.visualisation import plot_series # doctest: +SKIP
>>> y = load_airline() # doctest: +SKIP
>>> transformer = STLBootstrapTransformer(10) # doctest: +SKIP
>>> y_hat = transformer.fit_transform(y) # doctest: +SKIP
Expand Down Expand Up @@ -452,7 +452,7 @@ class MovingBlockBootstrapTransformer(BaseTransformer):
--------
>>> from aeon.transformations.bootstrap import MovingBlockBootstrapTransformer
>>> from aeon.datasets import load_airline
>>> from `aeon.visualisation import plot_series # doctest: +SKIP
>>> from aeon.visualisation import plot_series # doctest: +SKIP
>>> y = load_airline()
>>> transformer = MovingBlockBootstrapTransformer(10)
>>> y_hat = transformer.fit_transform(y)
Expand Down
25 changes: 13 additions & 12 deletions aeon/visualisation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,29 +7,30 @@
"plot_lags",
"plot_correlations",
"plot_windows",
"plot_time_series_with_change_points",
"plot_time_series_with_profiles",
# Results plotting
"plot_critical_difference",
"plot_boxplot_median",
"plot_scatter_predictions",
"plot_scatter",
# Segmentation plotting
"plot_time_series_with_change_points",
"plot_time_series_with_profiles",
# Clustering plotting
# Estimator plotting
"plot_cluster_algorithm",
"plot_temporal_importance_curves",
]
from aeon.visualisation._cluster_plotting import plot_cluster_algorithm
from aeon.visualisation._critical_difference import plot_critical_difference
from aeon.visualisation.results_plotting import (
plot_boxplot_median,
plot_scatter,
plot_scatter_predictions,

from aeon.visualisation.estimator._cluster_plotting import plot_cluster_algorithm
from aeon.visualisation.estimator._temporal_importance_curves import (
plot_temporal_importance_curves,
)
from aeon.visualisation.segmentation_plotting import (
from aeon.visualisation.results._boxplot import plot_boxplot_median
from aeon.visualisation.results._critical_difference import plot_critical_difference
from aeon.visualisation.results._scatter import plot_scatter, plot_scatter_predictions
from aeon.visualisation.series._segmentation import (
plot_time_series_with_change_points,
plot_time_series_with_profiles,
)
from aeon.visualisation.series_plotting import (
from aeon.visualisation.series._series import (
plot_correlations,
plot_interval,
plot_lags,
Expand Down
1 change: 1 addition & 0 deletions aeon/visualisation/estimator/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Plotting tools for estimators."""
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Cluster plotting tools."""

__author__ = ["Christopher Holder", "Tony Bagnall"]

__all__ = ["plot_cluster_algorithm"]

import numpy as np
Expand All @@ -10,63 +11,29 @@
from aeon.utils.validation.collection import convert_collection


def _plot(cluster_values, center, axes):
    """Draw every series of a cluster in blue, then the cluster centre in red.

    Parameters
    ----------
    cluster_values : iterable of iterables
        Groups of series; each series of each group is passed to ``axes.plot``.
    center : indexable
        Only ``center[0]`` is drawn, after all member series.
    axes : object with a ``plot(data, color=...)`` method
        Target the lines are drawn on.
    """
    # Flatten the two nesting levels lazily so the drawing order is unchanged.
    members = (series for group in cluster_values for series in group)
    for series in members:
        axes.plot(series, color="b")
    axes.plot(center[0], color="r")


def _get_cluster_values(cluster_indexes: np.ndarray, X: np.ndarray, k: int):
    """Group the entries of ``X`` by their assigned cluster label.

    Parameters
    ----------
    cluster_indexes : np.ndarray
        Cluster label of each entry of ``X``.
    X : np.ndarray
        Collection indexed along its first axis.
    k : int
        Number of clusters; labels ``0`` to ``k - 1`` are collected.

    Returns
    -------
    list of np.ndarray
        ``k`` arrays, the i-th holding the entries of ``X`` whose label is ``i``
        (empty along the first axis when no entry carries that label).
    """
    return [X[np.where(cluster_indexes == label)[0]] for label in range(k)]


def plot_series(X):
    """Plot each series of a collection on its own vertically stacked subplot.

    The collection is converted to the "numpy3D" format and the first channel
    (``X[i][0]``) of every series is drawn in blue. The figure is displayed with
    ``plt.show()``; nothing is returned.

    Parameters
    ----------
    X : collection of time series
        Any input accepted by ``convert_collection(X, "numpy3D")``.
    """
    _check_soft_dependencies("matplotlib")
    import matplotlib.patches as mpatches
    import matplotlib.pyplot as plt

    X = convert_collection(X, "numpy3D")
    plt.rcParams["figure.dpi"] = 100

    # Size the figure that is actually drawn on instead of opening a separate,
    # empty one first. squeeze=False keeps ``axes`` two-dimensional, so a
    # collection holding a single series is indexed like any other.
    _, axes = plt.subplots(
        nrows=len(X), ncols=1, figsize=(5, 10), squeeze=False
    )
    for series, series_axes in zip(X, axes[:, 0]):
        series_axes.plot(series[0], color="b")

    blue_patch = mpatches.Patch(color="blue", label="Series that belong to the cluster")
    plt.legend(
        handles=[blue_patch],
        loc="upper center",
        bbox_to_anchor=(0.5, -0.40),
        fancybox=True,
        shadow=True,
        ncol=5,
    )
    plt.tight_layout()
    plt.show()


def plot_cluster_algorithm(model: BaseClusterer, X, k: int):
"""Plot the results from a univariate partitioning algorithm.
Parameters
----------
model: BaseClusterer
Clustering model to plot
predict_series: np.ndarray or pd.Dataframe or List[pd.Dataframe]
X: np.ndarray or pd.Dataframe or List[pd.Dataframe]
The series to predict the values for
k: int
Number of centers
Returns
-------
fig: matplotlib.figure
Figure created.
Example
-------
>>> pass
"""
_check_soft_dependencies("matplotlib")

import matplotlib.patches as mpatches
import matplotlib.pyplot as plt

Expand Down Expand Up @@ -95,3 +62,19 @@ def plot_cluster_algorithm(model: BaseClusterer, X, k: int):
)
plt.tight_layout()
plt.show()


def _plot(cluster_values, center, axes):
    """Draw every series of a cluster in blue, then the cluster centre in red.

    Parameters
    ----------
    cluster_values : iterable of iterables
        Groups of series; each series of each group is passed to ``axes.plot``.
    center : indexable
        Only ``center[0]`` is drawn, after all member series.
    axes : object with a ``plot(data, color=...)`` method
        Target the lines are drawn on.
    """
    # Flatten the two nesting levels lazily so the drawing order is unchanged.
    members = (series for group in cluster_values for series in group)
    for series in members:
        axes.plot(series, color="b")
    axes.plot(center[0], color="r")


def _get_cluster_values(cluster_indexes: np.ndarray, X: np.ndarray, k: int):
    """Group the entries of ``X`` by their assigned cluster label.

    Parameters
    ----------
    cluster_indexes : np.ndarray
        Cluster label of each entry of ``X``.
    X : np.ndarray
        Collection indexed along its first axis.
    k : int
        Number of clusters; labels ``0`` to ``k - 1`` are collected.

    Returns
    -------
    list of np.ndarray
        ``k`` arrays, the i-th holding the entries of ``X`` whose label is ``i``
        (empty along the first axis when no entry carries that label).
    """
    return [X[np.where(cluster_indexes == label)[0]] for label in range(k)]
54 changes: 54 additions & 0 deletions aeon/visualisation/estimator/_temporal_importance_curves.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
"""Temporal importance curve diagram generators for interval forests."""

__author__ = ["MatthewMiddlehurst"]

__all__ = ["plot_temporal_importance_curves"]

import numpy as np

from aeon.utils.validation._dependencies import _check_soft_dependencies


def plot_temporal_importance_curves(
    curves, curve_names, top_curves_shown=None, plot_mean=True
):
    """Temporal importance curve diagram generator for interval forests.

    Plots the curves with the largest peak value, optionally together with the
    point-wise mean over all curves, and displays the figure with ``plt.show()``.

    Parameters
    ----------
    curves : sequence of sequences of float
        One curve per entry, each holding one value per time point. The curves
        must share a common length when ``plot_mean`` is True, as the mean is
        taken across curves for every time point.
    curve_names : sequence of str
        Legend label of each curve, in the same order as ``curves``.
    top_curves_shown : int or None, default=None
        Number of curves to draw, chosen by highest maximum value. ``None``
        draws every curve. Values larger than ``len(curves)`` draw every curve.
    plot_mean : bool, default=True
        Whether to add a dashed line for the mean of all curves.
    """
    # find attributes to display by max information gain for any time point.
    _check_soft_dependencies("matplotlib")

    import matplotlib.pyplot as plt

    n_curves = len(curves)
    if top_curves_shown is None:
        top_curves_shown = n_curves
    # Clamp the request to what is available: asking for more curves than exist
    # must not index past the end of the ranked list.
    n_shown = max(0, min(top_curves_shown, n_curves))

    max_ig = [max(curve) for curve in curves]
    top = sorted(range(n_curves), key=lambda i: max_ig[i], reverse=True)[:n_shown]

    # plot curves with highest max and the mean information gain for each time point if
    # enabled.
    for idx in top:
        plt.plot(
            curves[idx],
            label=curve_names[idx],
        )
    if plot_mean:
        plt.plot(
            list(np.mean(curves, axis=0)),
            "--",
            linewidth=3,
            label="Mean Information Gain",
        )
    plt.legend(
        bbox_to_anchor=(0.0, 1.02, 1.0, 0.102),
        loc="lower left",
        ncol=2,
        mode="expand",
        borderaxespad=0.0,
    )
    plt.xlabel("Time Point")
    plt.ylabel("Information Gain")

    plt.show()
1 change: 1 addition & 0 deletions aeon/visualisation/results/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Plotting tools for estimator results."""
139 changes: 139 additions & 0 deletions aeon/visualisation/results/_boxplot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
"""Functions for plotting results boxplot diagrams."""

__author__ = ["dguijo"]

__all__ = [
"plot_boxplot_median",
]

import numpy as np

from aeon.utils.validation._dependencies import _check_soft_dependencies


def plot_boxplot_median(
    results,
    labels,
    plot_type="violin",
    outliers=True,
    title=None,
    y_min=None,
    y_max=None,
):
    """
    Plot a box plot of distributions from the median.

    Each row of results is an independent experiment for each element in labels.
    This function works out the deviation from the median for each row, then plots
    a boxplot variant of each column.

    Parameters
    ----------
    results: np.array
        Scores (either accuracies or errors) of dataset x strategy
    labels: list of estimators
        List with names of the estimators
    plot_type: str, default = "violin"
        This function can create four sorts of distribution plots: "violin",
        "swarm", "boxplot" or "strip". "violin" plot features a kernel density
        estimation of the underlying distribution. "swarm" draws a categorical
        scatterplot with points adjusted to be non-overlapping. "strip" draws a
        categorical scatterplot using jitter to reduce overplotting.
    outliers: bool, default = True
        Only applies when plot_type is "boxplot".
    title: str, default = None
        Title to be shown in the top of the plot.
    y_min: float, default = None
        Min value for the y_axis of the plot.
    y_max: float, default = None
        Max value for the y_axis of the plot.

    Returns
    -------
    fig: matplotlib.figure
        Figure created.

    Raises
    ------
    ValueError
        If plot_type is not one of "violin", "boxplot", "swarm" or "strip".

    Example
    -------
    >>> from aeon.visualisation import plot_boxplot_median
    >>> from aeon.benchmarking.results_loaders import get_estimator_results_as_array
    >>> methods = ["IT", "WEASEL-Dilation", "HIVECOTE2", "FreshPRINCE"]
    >>> results = get_estimator_results_as_array(estimators=methods) # doctest: +SKIP
    >>> plot = plot_boxplot_median(results[0], methods) # doctest: +SKIP
    >>> plot.show() # doctest: +SKIP
    >>> plot.savefig("boxplot.pdf") # doctest: +SKIP
    """
    _check_soft_dependencies("matplotlib", "seaborn")
    import matplotlib.pyplot as plt
    import seaborn as sns

    # Work on a float array: the division below writes into a buffer shaped like
    # ``results``, which cannot hold the float output if the input is integer-typed.
    results = np.asarray(results, dtype=float)

    # Obtains deviation from median for each independent experiment: every score
    # is divided by the sum of itself and its row median (0 where that sum is 0).
    medians = np.median(results, axis=1)
    sum_results_medians = results + medians[:, np.newaxis]

    deviation_from_median = np.divide(
        results,
        sum_results_medians,
        out=np.zeros_like(results),
        where=sum_results_medians != 0,
    )

    fig = plt.figure(figsize=(10, 6), layout="tight")

    # Plots violin or boxplots
    if plot_type == "violin":
        plot = sns.violinplot(
            data=deviation_from_median,
            linewidth=0.2,
            palette="pastel",
            bw=0.3,
        )
    elif plot_type == "boxplot":
        plot = sns.boxplot(
            data=deviation_from_median,
            palette="pastel",
            showfliers=outliers,
        )
    elif plot_type == "swarm":
        plot = sns.swarmplot(
            data=deviation_from_median,
            linewidth=0.2,
            palette="pastel",
        )
    elif plot_type == "strip":
        plot = sns.stripplot(
            data=deviation_from_median,
            linewidth=0.2,
            palette="pastel",
        )
    else:
        raise ValueError(
            "plot_type must be one of 'violin', 'boxplot', 'swarm' or 'strip'."
        )

    # Modifying limits for y-axis. Limits are derived from the data unless this is
    # a boxplot with outliers hidden, in which case unset limits stay automatic.
    derive_limits = plot_type != "boxplot" or outliers
    if y_min is None and derive_limits:
        y_min = np.around(np.amin(deviation_from_median) - 0.05, 2)

    if y_max is None and derive_limits:
        y_max = np.around(np.amax(deviation_from_median) + 0.05, 2)

    plot.set_ylim(y_min, y_max)

    # Setting labels for x-axis. Rotate only if labels are too long.
    plot.set_xticks(np.arange(len(labels)))
    label_lengths = [len(label) for label in labels]
    # Combined length of each pair of neighbouring labels; empty when there are
    # fewer than two labels, hence the default of 0.
    widest_pair = max(
        (left + right for left, right in zip(label_lengths, label_lengths[1:])),
        default=0,
    )
    if sum(label_lengths) > 40 or widest_pair > 20:
        plot.set_xticklabels(labels, rotation=45, ha="right")
    else:
        plot.set_xticklabels(labels)

    # Setting title if provided.
    if title is not None:
        plot.set_title(rf"{title}")

    return fig
Loading

0 comments on commit 5c0d1d2

Please sign in to comment.