From a2025a074fa33d37206542379972a0cd96827e2a Mon Sep 17 00:00:00 2001 From: Jiahao Yuan Date: Wed, 6 Nov 2024 21:03:03 +0800 Subject: [PATCH] feat: element label in plot (#221) --- examples/plot.py | 13 +----- python/bosing/_bosing.pyi | 36 ++++++++++++-- python/bosing/_plot.py | 34 +++++++++++--- src/python.rs | 2 +- src/python/elements.rs | 83 ++++++++++++++++++++++++--------- src/python/elements/absolute.rs | 5 +- src/python/elements/grid.rs | 5 +- src/python/elements/stack.rs | 5 +- src/python/plot.rs | 30 ++++++++++-- src/quant.rs | 2 + src/schedule.rs | 13 +++++- stubtest-allowlist.txt | 2 + 12 files changed, 177 insertions(+), 53 deletions(-) diff --git a/examples/plot.py b/examples/plot.py index a56ec09..58bc59d 100644 --- a/examples/plot.py +++ b/examples/plot.py @@ -2,7 +2,7 @@ from bosing import Barrier, Play, ShiftPhase, Stack -xy = [Play(f"xy{i}", "hann", 1.0, 100e-9) for i in range(2)] +xy = [Play(f"xy{i}", "hann", 1.0, 100e-9, label=f"xy{i}") for i in range(2)] z = [ Stack(Play(f"z{i}", "hann", 1.0, 100e-9), ShiftPhase(f"xy{i}", 1.0)) for i in range(2) @@ -10,16 +10,7 @@ m = Stack(*(Play(f"m{i}", "hann", 1.0, 100e-9, plateau=200e-9) for i in range(2))) b = Barrier() -schedule = Stack( - xy[0], - xy[1], - b, - z[1], - b, - xy[1], - b, - m, -) +schedule = Stack(xy[0], xy[1], b, z[1], b, xy[1], b, m, label="root") schedule.plot() plt.show() diff --git a/python/bosing/_bosing.pyi b/python/bosing/_bosing.pyi index 5c6c5f6..18310a9 100644 --- a/python/bosing/_bosing.pyi +++ b/python/bosing/_bosing.pyi @@ -1,6 +1,6 @@ # ruff: noqa: PLR0913 -from collections.abc import Iterable, Mapping, Sequence -from typing import Any, ClassVar, Literal, final, type_check_only +from collections.abc import Iterable, Iterator, Mapping, Sequence +from typing import Any, ClassVar, Literal, final import numpy as np import numpy.typing as npt @@ -34,6 +34,8 @@ __all__ = [ "generate_waveforms", "generate_waveforms_with_states", "ItemKind", + "PlotArgs", + "PlotItem", ] _RichReprResult: TypeAlias = list[Any] @@ -124,6 +126,8 @@ class Element: def max_duration(self) -> float: ... @property def min_duration(self) -> float: ... + @property + def label(self) -> str: ... def measure(self) -> float: ... def plot( self, @@ -131,6 +135,7 @@ class Element: *, channels: Sequence[str] | None = ..., max_depth: int = ..., + show_label: bool = ..., ) -> Axes: ... @final @@ -155,6 +160,7 @@ class Play(Element): duration: float | None = ..., max_duration: float = ..., min_duration: float = ..., + label: str | None = ..., ) -> Self: ... @property def channel_id(self) -> str: ... @@ -191,6 +197,7 @@ class ShiftPhase(Element): duration: float | None = ..., max_duration: float = ..., min_duration: float = ..., + label: str | None = ..., ) -> Self: ... @property def channel_id(self) -> str: ... @@ -213,6 +220,7 @@ class SetPhase(Element): duration: float | None = ..., max_duration: float = ..., min_duration: float = ..., + label: str | None = ..., ) -> Self: ... @property def channel_id(self) -> str: ... @@ -235,6 +243,7 @@ class ShiftFreq(Element): duration: float | None = ..., max_duration: float = ..., min_duration: float = ..., + label: str | None = ..., ) -> Self: ... @property def channel_id(self) -> str: ... @@ -257,6 +266,7 @@ class SetFreq(Element): duration: float | None = ..., max_duration: float = ..., min_duration: float = ..., + label: str | None = ..., ) -> Self: ... @property def channel_id(self) -> str: ... @@ -279,6 +289,7 @@ class SwapPhase(Element): duration: float | None = ..., max_duration: float = ..., min_duration: float = ..., + label: str | None = ..., ) -> Self: ... @property def channel_id1(self) -> str: ... @@ -299,6 +310,7 @@ class Barrier(Element): duration: float | None = ..., max_duration: float = ..., min_duration: float = ..., + label: str | None = ..., ) -> Self: ... @property def channel_ids(self) -> Sequence[str]: ... @@ -320,6 +332,7 @@ class Repeat(Element): duration: float | None = ..., max_duration: float = ..., min_duration: float = ..., + label: str | None = ..., ) -> Self: ... @property def child(self) -> Element: ... @@ -350,6 +363,7 @@ class Stack(Element): duration: float | None = ..., max_duration: float = ..., min_duration: float = ..., + label: str | None = ..., ) -> Self: ... def with_children(self, *children: Element) -> Stack: ... @property @@ -384,6 +398,7 @@ class Absolute(Element): duration: float | None = ..., max_duration: float = ..., min_duration: float = ..., + label: str | None = ..., ) -> Self: ... def with_children(self, *children: _AbsoluteEntryLike) -> Absolute: ... @property @@ -444,6 +459,7 @@ class Grid(Element): duration: float | None = ..., max_duration: float = ..., min_duration: float = ..., + label: str | None = ..., ) -> Self: ... def with_children( self, @@ -471,7 +487,19 @@ class OscState: def with_time_shift(self, time: float) -> Self: ... def __rich_repr__(self) -> _RichReprResult: ... # undocumented -@type_check_only +@final +class PlotArgs: + @property + def ax(self) -> Axes | None: ... + @property + def blocks(self) -> Iterator[PlotItem]: ... + @property + def channels(self) -> list[str]: ... + @property + def max_depth(self) -> int: ... + @property + def show_label(self) -> bool: ... + @final class PlotItem: @property @@ -484,6 +512,8 @@ class PlotItem: def depth(self) -> int: ... @property def kind(self) -> ItemKind: ... + @property + def label(self) -> str | None: ... @final class ItemKind: diff --git a/python/bosing/_plot.py b/python/bosing/_plot.py index ef75f76..5105e01 100644 --- a/python/bosing/_plot.py +++ b/python/bosing/_plot.py @@ -5,6 +5,7 @@ import matplotlib.pyplot as plt import numpy as np +from matplotlib import patheffects from matplotlib.patches import PathPatch from matplotlib.path import Path from matplotlib.ticker import EngFormatter @@ -17,10 +18,11 @@ from matplotlib.axes import Axes from typing_extensions import TypeAlias - from bosing._bosing import PlotItem + from bosing._bosing import PlotArgs, PlotItem _RECTS: TypeAlias = defaultdict[ItemKind, list[tuple[float, float, float]]] _MARKERS: TypeAlias = defaultdict[ItemKind, tuple[list[float], list[float]]] + _TEXTS: TypeAlias = list[tuple[float, float, str]] COLORS = { @@ -87,10 +89,11 @@ def process_blocks( channels: Sequence[str], max_depth: int, channels_ystart: dict[str, int], -) -> tuple[_RECTS, _MARKERS]: +) -> tuple[_RECTS, _MARKERS, _TEXTS]: ch_stack: list[list[str]] = [] rects: _RECTS = defaultdict(list) markers: _MARKERS = defaultdict(lambda: ([], [])) + texts: _TEXTS = [] for x in blocks: manage_channel_stack(ch_stack, x) @@ -105,17 +108,23 @@ def process_blocks( my.append(y) else: rects[x.kind].append((x.start, y, x.span)) - return rects, markers + if x.label is not None: + texts.append((x.start, y, x.label)) + return rects, markers, texts -def plot( - ax: Axes | None, blocks: Iterator[PlotItem], channels: Sequence[str], max_depth: int -) -> Axes: +def plot(args: PlotArgs) -> Axes: + ax = args.ax + blocks = args.blocks + channels = args.channels + max_depth = args.max_depth + show_label = args.show_label + if ax is None: ax = plt.gca() channels_ystart = {c: i * (max_depth + 1) for i, c in enumerate(channels)} - rects, markers = process_blocks(blocks, channels, max_depth, channels_ystart) + rects, markers, texts = process_blocks(blocks, channels, max_depth, channels_ystart) for k, r in rects.items(): # numrects x [x, y, width] @@ -143,6 +152,17 @@ def plot( markersize=12, ) + if show_label: + for x, y, label in texts: + txt = ax.annotate(label, (x, y)) # pyright: ignore[reportUnknownMemberType] + # Add white outline to text + txt.set_path_effects( + [ + patheffects.Stroke(linewidth=2, foreground="white"), + patheffects.Normal(), + ] + ) + _ = ax.set_yticks(list(channels_ystart.values()), channels_ystart.keys()) # pyright: ignore[reportUnknownMemberType] ax.xaxis.set_major_formatter(EngFormatter(places=3)) _ = ax.set_xlabel("Time") # pyright: ignore[reportUnknownMemberType] diff --git a/src/python.rs b/src/python.rs index d0faff3..ae5bd8d 100644 --- a/src/python.rs +++ b/src/python.rs @@ -17,7 +17,7 @@ mod export { GridLength, GridLengthUnit, Play, Repeat, SetFreq, SetPhase, ShiftFreq, ShiftPhase, Stack, SwapPhase, }, - plot::ItemKind, + plot::{ItemKind, PlotArgs, PlotItem}, shapes::{Hann, Interp, Shape}, wavegen::{generate_waveforms, generate_waveforms_with_states, Channel, OscState}, }; diff --git a/src/python/elements.rs b/src/python/elements.rs index 51a31bd..981ed82 100644 --- a/src/python/elements.rs +++ b/src/python/elements.rs @@ -7,7 +7,7 @@ use std::{borrow::Borrow as _, fmt::Debug, sync::Arc}; use pyo3::{exceptions::PyValueError, prelude::*, types::DerefToPyAny}; use crate::{ - quant::{Amplitude, ChannelId, Frequency, Phase, ShapeId, Time}, + quant::{Amplitude, ChannelId, Frequency, Label, Phase, ShapeId, Time}, schedule::{ self, ElementCommon, ElementCommonBuilder, ElementRef, ElementVariant, Measure as _, }, @@ -87,24 +87,6 @@ impl Alignment { } } -fn extract_alignment(obj: &Bound) -> PyResult { - Alignment::convert(obj).and_then(|x| x.extract(obj.py())) -} - -fn extract_margin(obj: &Bound) -> PyResult<(Time, Time)> { - if let Ok(v) = obj.extract() { - let t = Time::new(v)?; - return Ok((t, t)); - } - if let Ok((v1, v2)) = obj.extract() { - let t1 = Time::new(v1)?; - let t2 = Time::new(v2)?; - return Ok((t1, t2)); - } - let msg = "Failed to convert the value to (float, float)."; - Err(PyValueError::new_err(msg)) -} - /// Base class for schedule elements. /// /// A schedule element is a node in the tree structure of a schedule similar to @@ -179,6 +161,7 @@ fn extract_margin(obj: &Bound) -> PyResult<(Time, Time)> { /// max_duration (float): Maximum duration of the element. Defaults to /// ``inf``. /// min_duration (float): Minimum duration of the element. Defaults to ``0``. +/// label (str | None): Label of the element. Defaults to ``None``. #[pyclass(module = "bosing", subclass, frozen)] #[derive(Debug, Clone)] pub(crate) struct Element(pub(super) ElementRef); @@ -215,6 +198,11 @@ impl Element { self.0.common.min_duration() } + #[getter] + fn label(&self) -> Option<&Label> { + self.0.common.label() + } + /// Measure the minimum total duration required by the element. /// /// This value includes both inner `duration` and outer `margin` of the element. @@ -234,18 +222,20 @@ impl Element { /// channels (Sequence[str] | None): Channels to plot. If ``None``, all channels are /// plotted. /// max_depth (int): Maximum depth to plot. Defaults to ``5``. + /// show_label (bool): Whether to show label of elements. Defaults to ``True``. /// /// Returns: /// matplotlib.axes.Axes: Axes with the plot. - #[pyo3(signature = (ax=None, *, channels=None, max_depth=5))] + #[pyo3(signature = (ax=None, *, channels=None, max_depth=5, show_label=true))] fn plot( &self, py: Python, ax: Option, channels: Option>, max_depth: usize, + show_label: bool, ) -> PyResult { - plot_element(py, self.0.clone(), ax, channels, max_depth) + plot_element(py, self.0.clone(), ax, channels, max_depth, show_label) } } @@ -278,6 +268,7 @@ where .expect("Element should have a valid variant") } + #[allow(clippy::too_many_arguments)] fn build_element( variant: Self::Variant, margin: Option<&Bound>, @@ -286,6 +277,7 @@ where duration: Option