Skip to content

Commit

Permalink
expand style functionality (#75)
Browse files Browse the repository at this point in the history
* expand style functionality

* add feedback

* Apply suggestions from code review

Co-authored-by: Oriol Abril-Pla <oriol.abril.pla@gmail.com>

---------

Co-authored-by: Oriol Abril-Pla <oriol.abril.pla@gmail.com>
  • Loading branch information
aloctavodia and OriolAbril authored Aug 1, 2024
1 parent ad8d487 commit 60aac9a
Showing 1 changed file with 65 additions and 2 deletions.
67 changes: 65 additions & 2 deletions src/arviz_plots/style.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Style/templating helpers."""

from arviz_base import rcParams


def use(name):
"""Set an arviz style as the default style/template for all available backends.
Expand All @@ -13,16 +15,77 @@ def use(name):
-----
There are some backends where default styles/templates are not supported.
"""
ok = False

try:
import matplotlib.pyplot as plt

if name in plt.style.available:
plt.style.use(name)
ok = True
except ImportError:
pass

try:
import plotly.io as pio

if name in pio.templates:
pio.templates.default = name
ok = True
except ImportError:
pass

if not ok:
raise ValueError(f"Style {name} not found.")


def available():
"""List available styles."""
styles = {}

try:
import matplotlib.pyplot as plt

plt.style.use(name)
styles["matplotlib"] = plt.style.available
except ImportError:
pass

try:
import plotly.io as pio

pio.templates.default = name
styles["plotly"] = list(pio.templates)
except ImportError:
pass

return styles


def get(name, backend=None):
"""Get the style/template with the given name.
Parameters
----------
name : str
Name of the style/template to get.
backend : str
Name of the backend to use. Options are 'matplotlib' and 'plotly'.
Defaults to ``rcParams["plot.backend"]``.
"""
if backend is None:
backend = rcParams["plot.backend"]
if backend not in ["matplotlib", "plotly"]:
raise ValueError(f"Default styles/templates are not supported for Backend {backend}")

if backend == "matplotlib":
import matplotlib.pyplot as plt

if name in plt.style.available:
return plt.style.library[name]

elif backend == "plotly":
import plotly.io as pio

if name in pio.templates:
return pio.templates[name]

raise ValueError(f"Style {name} not found.")

0 comments on commit 60aac9a

Please sign in to comment.