Skip to content

Commit

Permalink
Refactor axes
Browse files Browse the repository at this point in the history
  • Loading branch information
pheuer committed Dec 16, 2024
1 parent fabd5ff commit bd9922e
Showing 1 changed file with 134 additions and 73 deletions.
207 changes: 134 additions & 73 deletions src/cr39py/scan/base_scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,11 @@
from cr39py.scan.cut import Cut
from cr39py.scan.subset import Subset

__all__ = ["Scan"]
__all__ = ["Axis", "Scan"]


class _Axis(ExportableClassMixin):
class Axis(ExportableClassMixin):
"""Represents an axis of a CR-39 scan"""

_exportable_attributes = ["ind", "_unit", "_default_range", "framesize"]

Expand All @@ -35,7 +36,12 @@ def __init__(self, ind=None, unit=None, default_range=(None, None, None)) -> Non
self._ind = ind
self._unit = unit
self._default_range = default_range
self.framesize = None

# Framesize is mutable
self._framesize = None

# Tracks over which to calculate axes
self._tracks = None

@property
def ind(self) -> int:
Expand All @@ -48,39 +54,74 @@ def unit(self):
@property
def default_range(self):
"""
Default range is (min, max, framesize).
Default range (min, max, framesize) for this axis.
None means set based on data bounds
"""
return self._default_range

def setup(self, tracks: TrackData) -> None:
@cached_property
def _default_framesize(self) -> u.Quantity:
"""
Setup the axes for the provided track array.
Calculates an initial framesize based on the selected tracks.
"""
# If a default framesize was specified, return that
default_framesize = self.default_range[2]
if default_framesize is not None:
framesize = default_framesize
else:
# Otherwise, determine a framesize that will result in about
# 20 tracks per frame
ntracks = self.tracks.shape[0]
nbins = int(np.clip(np.sqrt(ntracks) / 20, 20, 200))
minval = np.min(self.tracks[:, self.ind])
maxval = np.max(self.tracks[:, self.ind])
framesize = (maxval - minval) / nbins

Parameters
----------
tracks : `~numpy.ndarray` (ntracks,6)
Tracks for which the axis should be initialized.
return framesize * self.unit

def _reset_default_framesize(self) -> None:
"""Resets the default framesize if the tracks change."""
if hasattr(self, "_default_framesize"):
del self._default_framesize

@property
def framesize(self) -> u.Quantity:
"""Frame (bin) size for this axis.
If framesize property is set, returns that value,
otherwise returns a default framesize estimated from
the current tracks.
Returns
-------
u.Quantity
_description_
"""
self._init_framesize(tracks)
if self._framesize is not None:
return self._framesize
else:
return self._default_framesize

def _init_framesize(self, tracks: TrackData) -> None:
@framesize.setter
def framesize(self, framesize: u.Quantity) -> None:
self._framesize = framesize
self._reset_axis()

@property
def tracks(self) -> TrackData:
"""
Calculates an initial framesize.
Tracks associated with this axis.
"""
framesize = self.default_range[2]
ntracks = tracks.shape[0]

if framesize is None:
nbins = int(np.clip(np.sqrt(ntracks) / 20, 20, 200))
minval = np.min(tracks[:, self.ind])
maxval = np.max(tracks[:, self.ind])
framesize = (maxval - minval) / nbins
return self._tracks

self.framesize = framesize * self.unit
@tracks.setter
def tracks(self, tracks: TrackData) -> None:
self._tracks = tracks
self._reset_default_framesize()
self._reset_axis()

def axis(self, tracks: TrackData, units: bool = True) -> np.ndarray | u.Quantity:
@cached_property
def axis(self) -> np.ndarray | u.Quantity:
"""
Axis calculated for the provided array of tracks.
Expand All @@ -90,29 +131,31 @@ def axis(self, tracks: TrackData, units: bool = True) -> np.ndarray | u.Quantity
tracks : `~numpy.ndarray` (ntracks,6)
Tracks for which the axis should be created.
units : bool
If True, return axis as a Quantity. Otherwise
return as a `~numpy.ndarray` in the base units
for this axis.
"""

# Calculate a min and max value for the axis
minval = self.default_range[0]
if minval is None:
minval = np.min(tracks[:, self.ind])
minval = np.min(self.tracks[:, self.ind])

maxval = self.default_range[1]
if maxval is None:
maxval = np.max(tracks[:, self.ind])
maxval = np.max(self.tracks[:, self.ind])

ax = np.arange(minval, maxval, self.framesize.m_as(self.unit))

if units:
ax *= self.unit
ax *= self.unit

return ax

def _reset_axis(self):
"""
Reset the axis to be recalculated if the tracks
or the framesize has changed.
"""
if hasattr(self, "axis"):
del self.axis


class Scan(ExportableClassMixin):
"""
Expand All @@ -128,12 +171,12 @@ class Scan(ExportableClassMixin):
"""

_axes = {
"X": _Axis(ind=0, unit=u.cm, default_range=(None, None, None)),
"Y": _Axis(ind=1, unit=u.cm, default_range=(None, None, None)),
"D": _Axis(ind=2, unit=u.um, default_range=(0, 20, 0.5)),
"C": _Axis(ind=3, unit=u.dimensionless, default_range=(0, 80, 1)),
"E": _Axis(ind=4, unit=u.dimensionless, default_range=(0, 50, 1)),
"Z": _Axis(ind=5, unit=u.um, default_range=(None, None, None)),
"X": Axis(ind=0, unit=u.cm, default_range=(None, None, None)),
"Y": Axis(ind=1, unit=u.cm, default_range=(None, None, None)),
"D": Axis(ind=2, unit=u.um, default_range=(0, 20, 0.5)),
"C": Axis(ind=3, unit=u.dimensionless, default_range=(0, 80, 1)),
"E": Axis(ind=4, unit=u.dimensionless, default_range=(0, 50, 1)),
"Z": Axis(ind=5, unit=u.um, default_range=(None, None, None)),
}

_exportable_attributes = [
Expand All @@ -146,7 +189,7 @@ class Scan(ExportableClassMixin):

def __init__(self) -> None:
self._current_subset_index = 0
self._subsets = []
self._subsets = [Subset()]

self._tracks = None

Expand Down Expand Up @@ -180,6 +223,15 @@ def etch_time(self) -> u.Quantity:
"""
return self._etch_time

@property
def axes(self) -> dict[Axis]:
"""
A dictionary of `~cr39py.scan.base_scan.Axis` objects.
Keys to the dictionary are the axis names, "X","Y","D","C","E","Z".
"""
return self._axes

# **********************************
# Class Methods for initialization
# **********************************
Expand All @@ -202,14 +254,9 @@ def from_tracks(cls, tracks: TrackData, etch_time: float):
obj._etch_time = etch_time * u.min
obj._tracks = tracks

# Initialize the axes based on the provided tracks
# Attach the selected tracks object to the axes objects
for ax in obj._axes.values():
ax.setup(obj._tracks)

# Initialize the list of subsets with a single subset to start.
obj._subsets = [
Subset(),
]
ax.tracks = obj._selected_tracks

return obj

Expand Down Expand Up @@ -479,6 +526,11 @@ def ntracks(self) -> int:
"""
return self._tracks.shape[0]

def _reset_selected_tracks(self):
"""Reset the cached properties associated with _selected_tracks."""
if hasattr(self, "_selected_tracks"):
del self._selected_tracks

@cached_property
def _selected_tracks(self) -> TrackData:
"""
Expand All @@ -489,12 +541,14 @@ def _selected_tracks(self) -> TrackData:
# property if the subset has changed, or if the framesize has
# changed
self._cached_subset_hash = hash(self.current_subset)
return self.current_subset.apply_cuts(self._tracks)

def _reset_selected_tracks(self):
"""Reset the cached selected tracks"""
if hasattr(self, "_selected_tracks"):
del self._selected_tracks
tracks = self.current_subset.apply_cuts(self._tracks)

# Re-attach the new selected tracks to the axes objects
for ax in self._axes.values():
ax.tracks = tracks

return tracks

@property
def selected_tracks(self) -> TrackData:
Expand All @@ -513,6 +567,7 @@ def selected_tracks(self) -> TrackData:
pass
# If not, delete the properties so they will be created again
else:
# Set the selected tracks to be re-generated.
self._reset_selected_tracks()

return self._selected_tracks
Expand Down Expand Up @@ -632,8 +687,6 @@ def histogram(

ax0 = self._axes[axes[0]]
ax1 = self._axes[axes[1]]
ax0_axis = ax0.axis(tracks, units=False)
ax1_axis = ax1.axis(tracks, units=False)

# If creating a histogram like the X,Y,D plots
if quantity is not None:
Expand All @@ -643,17 +696,20 @@ def histogram(
weights = None

rng = [
(np.min(ax0_axis), np.max(ax0_axis)),
(np.min(ax1_axis), np.max(ax1_axis)),
(np.min(ax0.axis.m), np.max(ax0.axis.m)),
(np.min(ax1.axis.m), np.max(ax1.axis.m)),
]
bins = [ax0_axis.size, ax1_axis.size]

arr = histogram2d(
tracks[:, ax0.ind],
tracks[:, ax1.ind],
bins=bins,
range=rng,
weights=weights,
bins = [ax0.axis.size, ax1.axis.size]

arr = (
histogram2d(
tracks[:, ax0.ind],
tracks[:, ax1.ind],
bins=bins,
range=rng,
weights=weights,
)
* u.dimensionless
)

# Create the unweighted histogram and divide by it (sans zeros)
Expand All @@ -666,10 +722,10 @@ def histogram(
)
nz = np.nonzero(arr_uw)
arr[nz] = arr[nz] / arr_uw[nz]
arr = arr * ax2.unit

return ax0_axis, ax1_axis, arr
return ax0.axis, ax1.axis, arr

@property
def overlap_parameter_histogram(self) -> tuple[np.ndarray]:
"""The Zylstra overlap parameter for each cell.
Expand All @@ -691,13 +747,18 @@ def overlap_parameter_histogram(self) -> tuple[np.ndarray]:
x, y, ntracks = self.histogram(axes=("X", "Y"))
x, y, D = self.histogram(axes=("X", "Y"), quantity="D")

print(self._axes["X"].framesize)
print(self._axes["X"].framesize)
print(D.u)
print(ntracks.u)

chi = (
ntracks
/ self._axes["X"].framesize
/ self._axes["Y"].framesize
* np.pi
* D**2
).m_as(u.dimensionless)
).to(u.dimensionless)

return x, y, chi

Expand Down Expand Up @@ -817,12 +878,12 @@ def plot(
title = f"{axes[0]}, {axes[1]}, {quantity}"

# Set any None bounds to the extrema of the ranges
xrange[0] = np.nanmin(xax) if xrange[0] is None else xrange[0]
xrange[1] = np.nanmax(xax) if xrange[1] is None else xrange[1]
yrange[0] = np.nanmin(yax) if yrange[0] is None else yrange[0]
yrange[1] = np.nanmax(yax) if yrange[1] is None else yrange[1]
zrange[0] = np.nanmin(arr) if zrange[0] is None else zrange[0]
zrange[1] = np.nanmax(arr) if zrange[1] is None else zrange[1]
xrange[0] = np.nanmin(xax.m) if xrange[0] is None else xrange[0]
xrange[1] = np.nanmax(xax.m) if xrange[1] is None else xrange[1]
yrange[0] = np.nanmin(yax.m) if yrange[0] is None else yrange[0]
yrange[1] = np.nanmax(yax.m) if yrange[1] is None else yrange[1]
zrange[0] = np.nanmin(arr.m) if zrange[0] is None else zrange[0]
zrange[1] = np.nanmax(arr.m) if zrange[1] is None else zrange[1]

# Apply log transform if requested
if log:
Expand All @@ -842,7 +903,7 @@ def plot(
ax.set_title(title, fontsize=fontsize)

try:
p = ax.pcolorfast(xax, yax, arr.T)
p = ax.pcolorfast(xax.m, yax.m, arr.m.T)

cb_kwargs = {
"orientation": "vertical",
Expand Down

0 comments on commit bd9922e

Please sign in to comment.