Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor spline properties to the property props #54

Merged
merged 6 commits into from
Jul 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions cylindra/_custom_layers.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from __future__ import annotations
from contextlib import contextmanager

from typing import TYPE_CHECKING, Any, NamedTuple
import weakref
Expand Down Expand Up @@ -120,6 +121,8 @@ def __init__(self, data: Molecules, **kwargs):
self._molecules = data
self._colormap_info: ColormapInfo | None = None
self._source_component: weakref.ReferenceType[BaseComponent] | None = None
self._old_name: str | None = None # for undo/redo
self._undo_renaming = False
super().__init__(data.pos, **kwargs)
features = data.features
if features is not None and len(features) > 0:
Expand Down Expand Up @@ -151,6 +154,33 @@ def features(self, features):
Points.features.fset(self, df)
self._molecules.features = df

@property
def name(self) -> str:
return super().name

@name.setter
def name(self, name: str) -> None:
if self.name == name:
return None
self._old_name = self.name
if not name:
name = self._basename()
self._name = str(name)
self.events.name()

@contextmanager
def _undo_context(self):
was_renaming = self._undo_renaming
self._undo_renaming = True
try:
yield
finally:
self._undo_renaming = was_renaming

def _rename(self, name: str):
with self._undo_context():
self.name = name

@property
def source_component(self) -> BaseComponent | None:
"""The source tomographic component object."""
Expand All @@ -171,6 +201,7 @@ def source_component(self, obj: BaseComponent | None):

@property
def source_spline(self) -> CylSpline | None:
"""The source component but limited to splines."""
from cylindra.components import CylSpline

src = self.source_component
Expand Down
36 changes: 15 additions & 21 deletions cylindra/components/cyl_spline.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
from __future__ import annotations

from typing import Callable, Any, TYPE_CHECKING
from typing import Any
import numpy as np
from numpy.typing import ArrayLike
import polars as pl

from .spline import Spline
Expand All @@ -14,32 +13,29 @@
from cylindra.utils import roundint
from cylindra.components.cylindric import CylinderModel

if TYPE_CHECKING:
Degenerative = Callable[[ArrayLike], Any]


class CylSpline(Spline):
"""A spline object with cylindrical structure."""

@property
def radius(self) -> nm | None:
"""Average radius of the cylinder."""
return self.get_globalprops(H.radius, None)
return self.props.get_glob(H.radius, None)

@radius.setter
def radius(self, value: nm | None):
if value is None:
if H.radius in self.globalprops.columns:
self.drop_globalprops(H.radius)
self.props.drop_glob(H.radius)
return None
col = pl.Series(H.radius, [value]).cast(pl.Float32)
self._globalprops = self.globalprops.with_columns(col)
self.props.glob = self.props.glob.with_columns(col)
return None

@property
def orientation(self) -> Ori:
"""Orientation of the spline."""
return Ori(str(self.get_globalprops(H.orientation, "none")))
return Ori(str(self.props.get_glob(H.orientation, "none")))

@orientation.setter
def orientation(self, value: Ori | str | None):
Expand All @@ -48,7 +44,7 @@ def orientation(self, value: Ori | str | None):
else:
value = Ori(value)
col = pl.Series(H.orientation, [str(value)])
self._globalprops = self.globalprops.with_columns(col)
self.props.glob = self.props.glob.with_columns(col)
return None

def invert(self) -> CylSpline:
Expand All @@ -62,11 +58,9 @@ def invert(self) -> CylSpline:
"""
# NOTE: invert() calls clip() internally.
# We don't have to invert the orientation here.
return (
super()
.invert()
.update_localprops(self.localprops[::-1], self.localprops_window_size)
)
new = super().invert()
new.props.update_loc(self.props.loc[::-1], self.props.window_size)
return new

def clip(self, start: float, stop: float) -> CylSpline:
"""
Expand All @@ -91,7 +85,7 @@ def clip(self, start: float, stop: float) -> CylSpline:
"""
clipped = super().clip(start, stop)

clipped._globalprops = self.globalprops.clone()
clipped.props.glob = self.props.glob.clone()
if start > stop:
clipped.orientation = Ori.invert(self.orientation)
else:
Expand Down Expand Up @@ -227,7 +221,7 @@ def update_props(
# update H.start
if rise is not None:
r = radius if radius is not None else self.radius
if r is not None and self.has_localprops([H.rise, H.spacing, H.skew]):
if r is not None and self.props.has_loc([H.rise, H.spacing, H.skew]):
_start_loc = rise_to_start(
rise=np.deg2rad(ldf[H.rise].to_numpy()),
space=ldf[H.spacing].to_numpy(),
Expand All @@ -237,7 +231,7 @@ def update_props(
ldf = ldf.with_columns(
pl.Series(_start_loc).cast(pl.Float32).alias(H.start)
)
if r is not None and self.has_globalprops([H.rise, H.spacing, H.skew]):
if r is not None and self.props.has_glob([H.rise, H.spacing, H.skew]):
_start_glob = rise_to_start(
rise=np.deg2rad(gdf[H.rise].to_numpy()),
space=gdf[H.spacing].to_numpy(),
Expand All @@ -248,8 +242,8 @@ def update_props(
pl.Series(_start_glob).cast(pl.Float32).alias(H.start)
)

self._localprops = ldf
self._globalprops = gdf
self.props.loc = ldf
self.props.glob = gdf
return self

def _need_rotation(self, orientation: Ori | str | None) -> bool:
Expand All @@ -274,4 +268,4 @@ def rise_to_start(rise: float, space: nm, skew: float, perimeter: nm) -> float:
def _get_globalprops(spl: CylSpline, kwargs: dict[str, Any], name: str):
if name in kwargs:
return kwargs[name]
return spl.get_globalprops(name)
return spl.props.get_glob(name)
37 changes: 19 additions & 18 deletions cylindra/components/cyl_tomogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -519,7 +519,7 @@ def refine(
self.set_radius(i=i, binsize=binsize)

_required = [H.spacing, H.skew, H.nPF]
if not spl.has_globalprops(_required):
if not spl.props.has_glob(_required):
self.global_ft_params(i=i, binsize=binsize)

spl.make_anchors(max_interval=max_interval)
Expand All @@ -529,9 +529,9 @@ def refine(
# Calculate Fourier parameters by cylindrical transformation along spline.
# Skew angles are divided by the angle of single protofilament and the residual
# angles are used, considering missing wedge effect.
interv = spl.get_globalprops(H.spacing) * 2
skew = spl.get_globalprops(H.skew)
npf = roundint(spl.get_globalprops(H.nPF))
interv = spl.props.get_glob(H.spacing) * 2
skew = spl.props.get_glob(H.skew)
npf = roundint(spl.props.get_glob(H.nPF))

LOGGER.info(
f" >> Parameters: spacing = {interv/2:.2f} nm, skew = {skew:.3f} deg, PF = {npf}"
Expand Down Expand Up @@ -754,7 +754,7 @@ def local_radii(
r_peak_sub = (imax_sub + 0.5) / nbin * r_max
radii.append(r_peak_sub)
out = pl.Series(H.radius, radii, dtype=pl.Float32)
spl.update_localprops([out], size)
spl.props.update_loc([out], size)
return out

@batch_process
Expand Down Expand Up @@ -814,8 +814,7 @@ def local_ft_params(
pl.Series(H.splPos, spl.anchors, dtype=pl.Float32),
pl.Series(H.splDist, spl.distances(), dtype=pl.Float32),
)

spl.update_localprops(lprops, ft_size)
spl.props.update_loc(lprops, ft_size)

return lprops

Expand Down Expand Up @@ -1017,7 +1016,7 @@ def infer_polarity(
polar = mask_spectra(polar)
img_flat = polar.proj("y")

if (npf := spl.get_globalprops(H.nPF, None)) is None:
if (npf := spl.props.get_glob(H.nPF, None)) is None:
# if the global properties are already calculated, use it
# otherwise, calculate the number of PFs from the power spectrum
ft = img_flat.fft(shift=False, dims="ra")
Expand Down Expand Up @@ -1245,11 +1244,11 @@ def map_centers(
Molecules object with mapped coordinates and angles.
"""
spl = self.splines[i]
if spl.has_globalprops([H.spacing, H.skew]):
if spl.props.has_glob([H.spacing, H.skew]):
self.global_ft_params(i=i)

interv = spl.get_globalprops(H.spacing) * 2
skew = spl.get_globalprops(H.skew)
interv = spl.props.get_glob(H.spacing) * 2
skew = spl.props.get_glob(H.skew)

# Set interval to the dimer length by default.
if interval is None:
Expand Down Expand Up @@ -1291,7 +1290,7 @@ def get_cylinder_model(
spl = self.splines[i]
_required = [H.spacing, H.skew, H.rise, H.nPF]
_missing = [k for k in _required if k not in kwargs]
if not spl.has_globalprops(_missing):
if not spl.props.has_glob(_missing):
self.global_ft_params(i=i)
return spl.cylinder_model(offsets=offsets, **kwargs)

Expand All @@ -1302,6 +1301,7 @@ def map_monomers(
*,
offsets: tuple[nm, float] | None = None,
orientation: Ori | str | None = None,
**kwargs,
) -> Molecules:
"""
Map monomers in a regular cylinder shape.
Expand All @@ -1320,7 +1320,7 @@ def map_monomers(
Molecules
Object that represents monomer positions and angles.
"""
model = self.get_cylinder_model(i, offsets=offsets)
model = self.get_cylinder_model(i, offsets=offsets, **kwargs)
yy, aa = np.indices(model.shape, dtype=np.int32)
coords = np.stack([yy.ravel(), aa.ravel()], axis=1)
spl = self.splines[i]
Expand All @@ -1337,6 +1337,7 @@ def map_on_grid(
*,
offsets: tuple[nm, float] | None = None,
orientation: Ori | str | None = None,
**kwargs,
) -> Molecules:
"""
Map monomers in a regular cylinder shape.
Expand All @@ -1357,7 +1358,7 @@ def map_on_grid(
Molecules
Object that represents monomer positions and angles.
"""
model = self.get_cylinder_model(i, offsets=offsets)
model = self.get_cylinder_model(i, offsets=offsets, **kwargs)
coords = np.asarray(coords, dtype=np.int32)
spl = self.splines[i]
mole = model.locate_molecules(spl, coords)
Expand Down Expand Up @@ -1394,10 +1395,10 @@ def map_pf_line(
Object that represents protofilament positions and angles.
"""
spl = self.splines[i]
if not spl.has_globalprops([H.spacing, H.skew]):
if not spl.props.has_glob([H.spacing, H.skew]):
self.global_ft_params(i=i)
interv = spl.get_globalprops(H.spacing) * 2
skew = spl.get_globalprops(H.skew)
interv = spl.props.get_glob(H.spacing) * 2
skew = spl.props.get_glob(H.skew)

if interval is None:
interval = interv
Expand Down Expand Up @@ -1555,7 +1556,7 @@ def _prepare_radii(
raise ValueError("Global radius is not measured yet.")
radii = np.full(spl.anchors.size, spl.radius, dtype=np.float32)
elif radius == "local":
if not spl.has_localprops(H.radius):
if not spl.props.has_loc(H.radius):
raise ValueError("Local radii is not measured yet.")
radii = spl.localprops[H.radius].to_numpy()
else:
Expand Down
Loading