Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add option to switch between plotting backends #326

Merged
merged 5 commits into from
Jun 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions src/gemdat/_plot_backend.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from __future__ import annotations

from types import ModuleType

from gemdat import plots as plots_default
from gemdat.plots import matplotlib as plots_matplotlib
from gemdat.plots import plotly as plots_plotly


def plot_backend(func):
"""Decorator to switch plotting backend."""

def wrap(*args, backend: str | None = None, **kwargs):
module: ModuleType

if backend is None:
module = plots_default
elif backend in ('mpl', 'matplotlib'):
module = plots_matplotlib
elif backend == 'plotly':
module = plots_plotly
else:
raise ValueError(f'No such backend: {backend}')

result = func(*args, module=module, **kwargs)

return result

wrap.__doc__ = func.__doc__
wrap.__doc__ += """

Parameters
---------
backend : str
Choose plotting backend. Options: matplotlib, mpl, plotly
Defaults to plotly unless the plot is only available in matplotlib.

Returns
-------
fig : plotly.graph_objects.Figure or matplotlib.figure.Figure depending on backend.
Output figure
"""

return wrap
36 changes: 16 additions & 20 deletions src/gemdat/jumps.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from pymatgen.core.units import FloatWithUnit
from scipy.constants import Boltzmann, angstrom, elementary_charge

from ._plot_backend import plot_backend
from .caching import weak_lru_cache
from .collective import Collective
from .simulation_metrics import SimulationMetrics
Expand Down Expand Up @@ -353,32 +354,27 @@ def rates(self, n_parts: int = 10) -> pd.DataFrame:

return df

def plot_jumps_vs_distance(self, **kwargs):
@plot_backend
def plot_jumps_vs_distance(self, *, module, **kwargs):
"""See [gemdat.plots.jumps_vs_distance][] for more information."""
from gemdat import plots
return module.jumps_vs_distance(jumps=self, **kwargs)

return plots.jumps_vs_distance(jumps=self, **kwargs)

def plot_jumps_vs_time(self, **kwargs):
@plot_backend
def plot_jumps_vs_time(self, *, module, **kwargs):
"""See [gemdat.plots.jumps_vs_time][] for more information."""
from gemdat import plots

return plots.jumps_vs_time(jumps=self, **kwargs)
return module.jumps_vs_time(jumps=self, **kwargs)

def plot_collective_jumps(self, **kwargs):
@plot_backend
def plot_collective_jumps(self, *, module, **kwargs):
"""See [gemdat.plots.collective_jumps][] for more information."""
from gemdat import plots

return plots.collective_jumps(jumps=self, **kwargs)
return module.collective_jumps(jumps=self, **kwargs)

def plot_jumps_3d(self, **kwargs):
@plot_backend
def plot_jumps_3d(self, *, module, **kwargs):
"""See [gemdat.plots.jumps_3d][] for more information."""
from gemdat import plots
return module.jumps_3d(jumps=self, **kwargs)

return plots.jumps_3d(jumps=self, **kwargs)

def plot_jumps_3d_animation(self, **kwargs):
@plot_backend
def plot_jumps_3d_animation(self, *, module, **kwargs):
"""See [gemdat.plots.jumps_3d_animation][] for more information."""
from gemdat import plots

return plots.jumps_3d_animation(jumps=self, **kwargs)
return module.jumps_3d_animation(jumps=self, **kwargs)
28 changes: 15 additions & 13 deletions src/gemdat/orientations.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
from __future__ import annotations

from dataclasses import InitVar, dataclass, field, replace
from typing import TYPE_CHECKING

import numpy as np
from pymatgen.symmetry.groups import PointGroup

from gemdat.trajectory import Trajectory
from gemdat.utils import cartesian_to_spherical, fft_autocorrelation

from ._plot_backend import plot_backend

if TYPE_CHECKING:
from gemdat.trajectory import Trajectory


@dataclass
class Orientations:
Expand Down Expand Up @@ -259,23 +264,20 @@ def autocorrelation(self):
"""Compute the autocorrelation of the orientation vectors using FFT."""
return fft_autocorrelation(self.vectors)

def plot_rectilinear(self, **kwargs):
@plot_backend
def plot_rectilinear(self, *, module, **kwargs):
"""See [gemdat.plots.rectilinear][] for more info."""
from gemdat import plots

return plots.rectilinear(orientations=self, **kwargs)
return module.rectilinear(orientations=self, **kwargs)

def plot_bond_length_distribution(self, **kwargs):
@plot_backend
def plot_bond_length_distribution(self, *, module, **kwargs):
"""See [gemdat.plots.bond_length_distribution][] for more info."""
from gemdat import plots
return module.bond_length_distribution(orientations=self, **kwargs)

return plots.bond_length_distribution(orientations=self, **kwargs)

def plot_autocorrelation(self, **kwargs):
@plot_backend
def plot_autocorrelation(self, *, module, **kwargs):
"""See [gemdat.plots.unit_vector_autocorrelation][] for more info."""
from gemdat import plots

return plots.autocorrelation(orientations=self, **kwargs)
return module.autocorrelation(orientations=self, **kwargs)


def calculate_spherical_areas(shape: tuple[int, int], radius: float = 1) -> np.ndarray:
Expand Down
15 changes: 7 additions & 8 deletions src/gemdat/path.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

from gemdat.volume import FreeEnergyVolume

from ._plot_backend import plot_backend
from .utils import nearest_structure_reference

if TYPE_CHECKING:
Expand Down Expand Up @@ -142,17 +143,15 @@ def stop_site(self) -> tuple[int, int, int]:
"""Return stop site."""
return self.sites[-1]

def plot_energy_along_path(self, **kwargs):
@plot_backend
def plot_energy_along_path(self, module, **kwargs):
"""See [gemdat.plots.energy_along_path][] for more info."""
from gemdat import plots
return module.energy_along_path(path=self, **kwargs)

return plots.energy_along_path(path=self, **kwargs)

def plot_path_on_grid(self, **kwargs):
@plot_backend
def plot_path_on_grid(self, module, **kwargs):
"""See [gemdat.plots.path_on_grid][] for more info."""
from gemdat import plots

return plots.path_on_grid(path=self, **kwargs)
return module.path_on_grid(path=self, **kwargs)


def free_energy_graph(
Expand Down
5 changes: 4 additions & 1 deletion src/gemdat/plots/matplotlib/_autocorrelation.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
from __future__ import annotations

from typing import TYPE_CHECKING

import matplotlib.pyplot as plt
import numpy as np

from gemdat.orientations import Orientations
if TYPE_CHECKING:
from gemdat.orientations import Orientations


def autocorrelation(
Expand Down
7 changes: 5 additions & 2 deletions src/gemdat/plots/matplotlib/_bond_length_distribution.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
from __future__ import annotations

import matplotlib.pyplot as plt
from typing import TYPE_CHECKING

from gemdat.orientations import Orientations
import matplotlib.pyplot as plt

from .._shared import _fit_skewnorm_to_hist, _orientations_to_histogram

if TYPE_CHECKING:
from gemdat.orientations import Orientations


def bond_length_distribution(*, orientations: Orientations, bins: int = 50) -> plt.Figure:
"""Plot the bond length probability distribution.
Expand Down
5 changes: 4 additions & 1 deletion src/gemdat/plots/matplotlib/_displacement_histogram.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from __future__ import annotations

from typing import TYPE_CHECKING

import matplotlib.pyplot as plt

from gemdat.trajectory import Trajectory
if TYPE_CHECKING:
from gemdat.trajectory import Trajectory


def displacement_histogram(trajectory: Trajectory) -> plt.Figure:
Expand Down
5 changes: 4 additions & 1 deletion src/gemdat/plots/matplotlib/_displacement_per_atom.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from __future__ import annotations

from typing import TYPE_CHECKING

import matplotlib.pyplot as plt

from gemdat.trajectory import Trajectory
if TYPE_CHECKING:
from gemdat.trajectory import Trajectory


def displacement_per_atom(*, trajectory: Trajectory) -> plt.Figure:
Expand Down
6 changes: 5 additions & 1 deletion src/gemdat/plots/matplotlib/_displacement_per_element.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
from __future__ import annotations

from typing import TYPE_CHECKING

import matplotlib.pyplot as plt

from gemdat.plots._shared import _mean_displacements_per_element
from gemdat.trajectory import Trajectory

if TYPE_CHECKING:
from gemdat.trajectory import Trajectory


def displacement_per_element(*, trajectory: Trajectory) -> plt.Figure:
Expand Down
5 changes: 4 additions & 1 deletion src/gemdat/plots/matplotlib/_energy_along_path.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
from __future__ import annotations

from typing import TYPE_CHECKING

import matplotlib.pyplot as plt
from pymatgen.core import Structure

from gemdat.path import Pathway
if TYPE_CHECKING:
from gemdat.path import Pathway


def energy_along_path(
Expand Down
6 changes: 5 additions & 1 deletion src/gemdat/plots/matplotlib/_frequency_vs_occurence.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
from __future__ import annotations

from typing import TYPE_CHECKING

import matplotlib.pyplot as plt
import numpy as np

from gemdat.simulation_metrics import SimulationMetrics
from gemdat.trajectory import Trajectory

if TYPE_CHECKING:
from gemdat.trajectory import Trajectory


def frequency_vs_occurence(*, trajectory: Trajectory) -> plt.Figure:
Expand Down
5 changes: 4 additions & 1 deletion src/gemdat/plots/matplotlib/_msd_per_element.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
from __future__ import annotations

from typing import TYPE_CHECKING

import matplotlib.pyplot as plt
import numpy as np

from gemdat.trajectory import Trajectory
if TYPE_CHECKING:
from gemdat.trajectory import Trajectory


def msd_per_element(
Expand Down
10 changes: 6 additions & 4 deletions src/gemdat/plots/matplotlib/_rectilinear.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from __future__ import annotations

from typing import TYPE_CHECKING

import matplotlib.pyplot as plt
import numpy as np

from gemdat.orientations import (
Orientations,
calculate_spherical_areas,
)
if TYPE_CHECKING:
from gemdat.orientations import Orientations


def rectilinear(
Expand All @@ -32,6 +32,8 @@ def rectilinear(
fig : matplotlib.figure.Figure
Output figure
"""
from gemdat.orientations import calculate_spherical_areas

az, el, _ = orientations.vectors_spherical.T
az = az.flatten()
el = el.flatten()
Expand Down
6 changes: 5 additions & 1 deletion src/gemdat/plots/matplotlib/_vibrational_amplitudes.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
from __future__ import annotations

from typing import TYPE_CHECKING

import matplotlib.pyplot as plt
import numpy as np
from scipy import stats

from gemdat.simulation_metrics import SimulationMetrics
from gemdat.trajectory import Trajectory

if TYPE_CHECKING:
from gemdat.trajectory import Trajectory


def vibrational_amplitudes(*, trajectory: Trajectory) -> plt.Figure:
Expand Down
5 changes: 4 additions & 1 deletion src/gemdat/plots/plotly/_autocorrelation.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
from __future__ import annotations

from typing import TYPE_CHECKING

import numpy as np
import plotly.graph_objects as go

from gemdat.orientations import Orientations
if TYPE_CHECKING:
from gemdat.orientations import Orientations


def autocorrelation(
Expand Down
7 changes: 5 additions & 2 deletions src/gemdat/plots/plotly/_bond_length_distribution.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
from __future__ import annotations

from typing import TYPE_CHECKING

import plotly.express as px
import plotly.graph_objects as go

from gemdat.orientations import Orientations

from .._shared import _fit_skewnorm_to_hist, _orientations_to_histogram

if TYPE_CHECKING:
from gemdat.orientations import Orientations


def bond_length_distribution(*, orientations: Orientations, bins: int = 50) -> go.Figure:
"""Plot the bond length probability distribution.
Expand Down
5 changes: 4 additions & 1 deletion src/gemdat/plots/plotly/_displacement_histogram.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
from __future__ import annotations

from typing import TYPE_CHECKING

import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go

from gemdat.trajectory import Trajectory
if TYPE_CHECKING:
from gemdat.trajectory import Trajectory


def _trajectory_to_dataframe(trajectory: Trajectory) -> pd.DataFrame:
Expand Down
Loading
Loading