Skip to content

Commit

Permalink
[Cherry pick] defer matplotlib (#316)
Browse files Browse the repository at this point in the history
* defer matplotlib import error until used

* style

* Added an exact matplotlib version in ImportError
  • Loading branch information
rahul-tuli authored May 22, 2023
1 parent 01eea48 commit de164a0
Showing 1 changed file with 25 additions and 1 deletion.
26 changes: 25 additions & 1 deletion src/sparsezoo/analyze/utils/chart.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,15 @@

import numpy

import matplotlib.pyplot as plt

try:
import matplotlib.pyplot as plt

matplotlib_available = True
except ImportError:
plt = None
matplotlib_available = False

from sparsezoo.analyze.analysis import ModelAnalysis, NodeAnalysis


Expand All @@ -28,6 +36,18 @@
]


def check_matplotlib_installed() -> None:
"""
Checks if matplotlib is installed and
raises an ImportError if not
"""
if not matplotlib_available:
raise ImportError(
"matplotlib is required to use this function, "
"please install it with `pip install matplotlib>=3.0.0`"
)


def draw_sparsity_by_layer_chart(
model_analysis: ModelAnalysis,
out_path: Optional[str] = None,
Expand All @@ -45,6 +65,7 @@ def draw_sparsity_by_layer_chart(
:param figsize: keyword argument to pass to matplotlib figure
:return: None
"""
check_matplotlib_installed()
figure, axes = plt.subplots(figsize=figsize)

# Set title
Expand Down Expand Up @@ -113,6 +134,7 @@ def draw_parameter_chart(
:param figsize: keyword argument to pass to matplotlib figure
:return: None
"""
check_matplotlib_installed()
figure, param_axes = plt.subplots(figsize=figsize)

# Set title
Expand Down Expand Up @@ -200,6 +222,7 @@ def draw_operation_chart(
:param figsize: keyword argument to pass to matplotlib figure
:return: None
"""
check_matplotlib_installed()
figure, ops_axes = plt.subplots(figsize=figsize)

# Set title
Expand Down Expand Up @@ -286,6 +309,7 @@ def draw_parameter_operation_combined_chart(
:param figsize: keyword argument to pass to matplotlib figure
:return: None
"""
check_matplotlib_installed()
figure, param_axes = plt.subplots(figsize=figsize)
ops_axes = param_axes.twinx()

Expand Down

0 comments on commit de164a0

Please sign in to comment.