diff --git a/pypots/utils/visual/visualizeAttention.py b/pypots/utils/visual/visualizeAttention.py
index 3c6b6899..7eaf6dd4 100644
--- a/pypots/utils/visual/visualizeAttention.py
+++ b/pypots/utils/visual/visualizeAttention.py
@@ -1,16 +1,27 @@
+"""
+Utilities for attention map visualization.
+"""
+
+# Created by Anshuman Swain and Wenjie Du
+# License: BSD-3-Clause
+
 import matplotlib.pyplot as plt
 import numpy as np
-import seaborn as sns
 from numpy.typing import ArrayLike
 
+try:
+    import seaborn as sns
+except Exception:
+    pass
+
 
-def visualize_attention(timeSteps: ArrayLike, attention: np.ndarray, fontscale = None):
-    """Visualize the map of attention weights from Transformer-based models
+def plot_attention(timeSteps: ArrayLike, attention: np.ndarray, fontscale=None):
+    """Visualize the map of attention weights from Transformer-based models.
 
     Parameters
     ---------------
     timeSteps: 1D array-like object, preferable list of strings
-        A vector containing the time steps of the input.
+        A vector containing the time steps of the input. The time steps will be converted to a list of strings if they are not already.
 
     attention: 2D array-like object
 
@@ -30,7 +41,7 @@ def visualize_attention(timeSteps: ArrayLike, attention: np.ndarray, fontscale =
     timeSteps = [str(step) for step in timeSteps]
 
     if fontscale is not None:
-        sns.set_theme(font_scale = fontscale)
+        sns.set_theme(font_scale=fontscale)
 
     fig, ax = plt.subplots()
     ax.tick_params(left=True, bottom=True, labelsize=10)