Skip to content

Commit

Permalink
refactor: rename into attention_map;
Browse files Browse the repository at this point in the history
  • Loading branch information
WenjieDu committed Jul 25, 2024
1 parent b655637 commit 9f9d6aa
Showing 1 changed file with 16 additions and 5 deletions.
21 changes: 16 additions & 5 deletions pypots/utils/visual/visualizeAttention.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,27 @@
"""
Utilities for attention map visualization.
"""

# Created by Anshuman Swain <aswai@seas.upenn.edu> and Wenjie Du <wenjay.du@gmail.com>
# License: BSD-3-Clause

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from numpy.typing import ArrayLike

try:
import seaborn as sns
except Exception:
pass


def visualize_attention(timeSteps: ArrayLike, attention: np.ndarray, fontscale = None):
"""Visualize the map of attention weights from Transformer-based models
def plot_attention(timeSteps: ArrayLike, attention: np.ndarray, fontscale=None):
"""Visualize the map of attention weights from Transformer-based models.
Parameters
---------------
timeSteps: 1D array-like object, preferable list of strings
A vector containing the time steps of the input.
A vector containing the time steps of the input.
The time steps will be converted to a list of strings if they are not already.
attention: 2D array-like object
Expand All @@ -30,7 +41,7 @@ def visualize_attention(timeSteps: ArrayLike, attention: np.ndarray, fontscale =
timeSteps = [str(step) for step in timeSteps]

if fontscale is not None:
sns.set_theme(font_scale = fontscale)
sns.set_theme(font_scale=fontscale)

fig, ax = plt.subplots()
ax.tick_params(left=True, bottom=True, labelsize=10)
Expand Down

0 comments on commit 9f9d6aa

Please sign in to comment.