Skip to content

Commit

Permalink
feat: element label in plot (#221)
Browse files Browse the repository at this point in the history
  • Loading branch information
kahojyun authored Nov 6, 2024
1 parent 479b92e commit a2025a0
Show file tree
Hide file tree
Showing 12 changed files with 177 additions and 53 deletions.
13 changes: 2 additions & 11 deletions examples/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,15 @@

from bosing import Barrier, Play, ShiftPhase, Stack

xy = [Play(f"xy{i}", "hann", 1.0, 100e-9) for i in range(2)]
xy = [Play(f"xy{i}", "hann", 1.0, 100e-9, label=f"xy{i}") for i in range(2)]
z = [
Stack(Play(f"z{i}", "hann", 1.0, 100e-9), ShiftPhase(f"xy{i}", 1.0))
for i in range(2)
]
m = Stack(*(Play(f"m{i}", "hann", 1.0, 100e-9, plateau=200e-9) for i in range(2)))
b = Barrier()

schedule = Stack(
xy[0],
xy[1],
b,
z[1],
b,
xy[1],
b,
m,
)
schedule = Stack(xy[0], xy[1], b, z[1], b, xy[1], b, m, label="root")

schedule.plot()
plt.show()
36 changes: 33 additions & 3 deletions python/bosing/_bosing.pyi
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# ruff: noqa: PLR0913
from collections.abc import Iterable, Mapping, Sequence
from typing import Any, ClassVar, Literal, final, type_check_only
from collections.abc import Iterable, Iterator, Mapping, Sequence
from typing import Any, ClassVar, Literal, final

import numpy as np
import numpy.typing as npt
Expand Down Expand Up @@ -34,6 +34,8 @@ __all__ = [
"generate_waveforms",
"generate_waveforms_with_states",
"ItemKind",
"PlotArgs",
"PlotItem",
]

_RichReprResult: TypeAlias = list[Any]
Expand Down Expand Up @@ -124,13 +126,16 @@ class Element:
def max_duration(self) -> float: ...
@property
def min_duration(self) -> float: ...
@property
def label(self) -> str: ...
def measure(self) -> float: ...
def plot(
self,
ax: Axes | None = ...,
*,
channels: Sequence[str] | None = ...,
max_depth: int = ...,
show_label: bool = ...,
) -> Axes: ...

@final
Expand All @@ -155,6 +160,7 @@ class Play(Element):
duration: float | None = ...,
max_duration: float = ...,
min_duration: float = ...,
label: str | None = ...,
) -> Self: ...
@property
def channel_id(self) -> str: ...
Expand Down Expand Up @@ -191,6 +197,7 @@ class ShiftPhase(Element):
duration: float | None = ...,
max_duration: float = ...,
min_duration: float = ...,
label: str | None = ...,
) -> Self: ...
@property
def channel_id(self) -> str: ...
Expand All @@ -213,6 +220,7 @@ class SetPhase(Element):
duration: float | None = ...,
max_duration: float = ...,
min_duration: float = ...,
label: str | None = ...,
) -> Self: ...
@property
def channel_id(self) -> str: ...
Expand All @@ -235,6 +243,7 @@ class ShiftFreq(Element):
duration: float | None = ...,
max_duration: float = ...,
min_duration: float = ...,
label: str | None = ...,
) -> Self: ...
@property
def channel_id(self) -> str: ...
Expand All @@ -257,6 +266,7 @@ class SetFreq(Element):
duration: float | None = ...,
max_duration: float = ...,
min_duration: float = ...,
label: str | None = ...,
) -> Self: ...
@property
def channel_id(self) -> str: ...
Expand All @@ -279,6 +289,7 @@ class SwapPhase(Element):
duration: float | None = ...,
max_duration: float = ...,
min_duration: float = ...,
label: str | None = ...,
) -> Self: ...
@property
def channel_id1(self) -> str: ...
Expand All @@ -299,6 +310,7 @@ class Barrier(Element):
duration: float | None = ...,
max_duration: float = ...,
min_duration: float = ...,
label: str | None = ...,
) -> Self: ...
@property
def channel_ids(self) -> Sequence[str]: ...
Expand All @@ -320,6 +332,7 @@ class Repeat(Element):
duration: float | None = ...,
max_duration: float = ...,
min_duration: float = ...,
label: str | None = ...,
) -> Self: ...
@property
def child(self) -> Element: ...
Expand Down Expand Up @@ -350,6 +363,7 @@ class Stack(Element):
duration: float | None = ...,
max_duration: float = ...,
min_duration: float = ...,
label: str | None = ...,
) -> Self: ...
def with_children(self, *children: Element) -> Stack: ...
@property
Expand Down Expand Up @@ -384,6 +398,7 @@ class Absolute(Element):
duration: float | None = ...,
max_duration: float = ...,
min_duration: float = ...,
label: str | None = ...,
) -> Self: ...
def with_children(self, *children: _AbsoluteEntryLike) -> Absolute: ...
@property
Expand Down Expand Up @@ -444,6 +459,7 @@ class Grid(Element):
duration: float | None = ...,
max_duration: float = ...,
min_duration: float = ...,
label: str | None = ...,
) -> Self: ...
def with_children(
self,
Expand Down Expand Up @@ -471,7 +487,19 @@ class OscState:
def with_time_shift(self, time: float) -> Self: ...
def __rich_repr__(self) -> _RichReprResult: ... # undocumented

@type_check_only
@final
class PlotArgs:
@property
def ax(self) -> Axes | None: ...
@property
def blocks(self) -> Iterator[PlotItem]: ...
@property
def channels(self) -> list[str]: ...
@property
def max_depth(self) -> int: ...
@property
def show_label(self) -> bool: ...

@final
class PlotItem:
@property
Expand All @@ -484,6 +512,8 @@ class PlotItem:
def depth(self) -> int: ...
@property
def kind(self) -> ItemKind: ...
@property
def label(self) -> str | None: ...

@final
class ItemKind:
Expand Down
34 changes: 27 additions & 7 deletions python/bosing/_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import matplotlib.pyplot as plt
import numpy as np
from matplotlib import patheffects
from matplotlib.patches import PathPatch
from matplotlib.path import Path
from matplotlib.ticker import EngFormatter
Expand All @@ -17,10 +18,11 @@
from matplotlib.axes import Axes
from typing_extensions import TypeAlias

from bosing._bosing import PlotItem
from bosing._bosing import PlotArgs, PlotItem

_RECTS: TypeAlias = defaultdict[ItemKind, list[tuple[float, float, float]]]
_MARKERS: TypeAlias = defaultdict[ItemKind, tuple[list[float], list[float]]]
_TEXTS: TypeAlias = list[tuple[float, float, str]]


COLORS = {
Expand Down Expand Up @@ -87,10 +89,11 @@ def process_blocks(
channels: Sequence[str],
max_depth: int,
channels_ystart: dict[str, int],
) -> tuple[_RECTS, _MARKERS]:
) -> tuple[_RECTS, _MARKERS, _TEXTS]:
ch_stack: list[list[str]] = []
rects: _RECTS = defaultdict(list)
markers: _MARKERS = defaultdict(lambda: ([], []))
texts: _TEXTS = []

for x in blocks:
manage_channel_stack(ch_stack, x)
Expand All @@ -105,17 +108,23 @@ def process_blocks(
my.append(y)
else:
rects[x.kind].append((x.start, y, x.span))
return rects, markers
if x.label is not None:
texts.append((x.start, y, x.label))
return rects, markers, texts


def plot(
ax: Axes | None, blocks: Iterator[PlotItem], channels: Sequence[str], max_depth: int
) -> Axes:
def plot(args: PlotArgs) -> Axes:
ax = args.ax
blocks = args.blocks
channels = args.channels
max_depth = args.max_depth
show_label = args.show_label

if ax is None:
ax = plt.gca()

channels_ystart = {c: i * (max_depth + 1) for i, c in enumerate(channels)}
rects, markers = process_blocks(blocks, channels, max_depth, channels_ystart)
rects, markers, texts = process_blocks(blocks, channels, max_depth, channels_ystart)

for k, r in rects.items():
# numrects x [x, y, width]
Expand Down Expand Up @@ -143,6 +152,17 @@ def plot(
markersize=12,
)

if show_label:
for x, y, label in texts:
txt = ax.annotate(label, (x, y)) # pyright: ignore[reportUnknownMemberType]
# Add white outline to text
txt.set_path_effects(
[
patheffects.Stroke(linewidth=2, foreground="white"),
patheffects.Normal(),
]
)

_ = ax.set_yticks(list(channels_ystart.values()), channels_ystart.keys()) # pyright: ignore[reportUnknownMemberType]
ax.xaxis.set_major_formatter(EngFormatter(places=3))
_ = ax.set_xlabel("Time") # pyright: ignore[reportUnknownMemberType]
Expand Down
2 changes: 1 addition & 1 deletion src/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ mod export {
GridLength, GridLengthUnit, Play, Repeat, SetFreq, SetPhase, ShiftFreq, ShiftPhase,
Stack, SwapPhase,
},
plot::ItemKind,
plot::{ItemKind, PlotArgs, PlotItem},
shapes::{Hann, Interp, Shape},
wavegen::{generate_waveforms, generate_waveforms_with_states, Channel, OscState},
};
Expand Down
Loading

0 comments on commit a2025a0

Please sign in to comment.