Skip to content

Commit

Permalink
Merge pull request #87 from nel-lab/cache-invalidate
Browse files Browse the repository at this point in the history
dfof, eval, other final things
  • Loading branch information
kushalkolar committed Aug 1, 2022
2 parents 5f62bab + f603711 commit ce35f88
Show file tree
Hide file tree
Showing 2 changed files with 224 additions and 2 deletions.
37 changes: 37 additions & 0 deletions mesmerize_core/caiman_extensions/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,3 +186,40 @@ def _use_cache(instance, *args, **kwargs):
return _return_wrapper(return_val, copy_bool=return_copy)

return _use_cache

def invalidate(self, pre: bool = True, post: bool = True):
"""
invalidate all cache entries associated to a single batch item
Parameters
----------
pre: bool
invalidate before the decorated function has been fun
post: bool
invalidate after the decorated function has been fun
"""
def _invalidate(func):
@wraps(func)
def __invalidate(instance, *args, **kwargs):
u = instance._series["uuid"]

if pre:
self.cache.drop(
self.cache.loc[self.cache["uuid"] == u].index,
inplace=True
)

rval = func(instance, *args, **kwargs)

if post:
self.cache.drop(
self.cache.loc[self.cache["uuid"] == u].index,
inplace=True
)

return rval

return __invalidate
return _invalidate
189 changes: 187 additions & 2 deletions mesmerize_core/caiman_extensions/cnmf.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from caiman.source_extraction.cnmf.cnmf import load_CNMF
from caiman.utils.visualization import get_contours as caiman_get_contours
from functools import wraps
import os
from copy import deepcopy

from .common import validate
from .cache import Cache
Expand Down Expand Up @@ -60,6 +62,20 @@ def _parser(instance, *args, **kwargs) -> Any:
return _parser


def _check_permissions(func):
@wraps(func)
def __check(instance, *args, **kwargs):
cnmf_obj_path = instance.get_output_path()

if not os.access(cnmf_obj_path, os.W_OK):
raise PermissionError(
"You do not have write access to the hdf5 output file for this batch item"
)

return func(instance, *args, **kwargs)
return __check


@pd.api.extensions.register_series_accessor("cnmf")
class CNMFExtensions:
"""
Expand Down Expand Up @@ -145,7 +161,7 @@ def get_masks(
) -> np.ndarray:
"""
| Get binary masks of the spatial components at the given ``component_indices``.
| Created from cnmf.estimates.A.
| Created from ``CNMF.estimates.A``
Parameters
----------
Expand Down Expand Up @@ -262,7 +278,7 @@ def get_temporal(
self, component_indices: Union[np.ndarray, str] = None, add_background: bool = False, return_copy=True
) -> np.ndarray:
"""
Get the temporal components for this CNMF item, basically ``cnmf.estimates.C``
Get the temporal components for this CNMF item, basically ``CNMF.estimates.C``
Parameters
----------
Expand Down Expand Up @@ -430,3 +446,172 @@ def get_residuals(
residuals = raw_movie[np.arange(*frame_indices)] - reconstructed_movie - background

return residuals.reshape(cnmf_obj.dims + (-1,), order="F").transpose([2, 0, 1])

@validate("cnmf")
@_check_permissions
@cache.invalidate()
def run_detrend_dfof(
self,
quantileMin: float = 8,
frames_window: int = 500,
flag_auto: bool = True,
use_fast: bool = False,
use_residuals: bool = True,
detrend_only: bool = False
) -> None:
"""
| Uses caiman's detrend_df_f.
| call ``cnmf.get_detrend_dfof()`` to get the values.
| Sets ``CNMF.estimates.F_dff``
Warnings
--------
Overwrites the existing cnmf hdf5 outfile file for this batch item
Parameters
----------
quantile_min: float
quantile used to estimate the baseline (values in [0,100])
used only if 'flag_auto' is False, i.e. ignored by default
frames_window: int
number of frames for computing running quantile
flag_auto: bool
flag for determining quantile automatically
use_fast: bool
flag for using approximate fast percentile filtering
detrend_only: bool
flag for only subtracting baseline and not normalizing by it.
Used in 1p data processing where baseline fluorescence cannot be
determined.
Returns
-------
None
Notes
------
invalidates the cache for this batch item.
"""

cnmf_obj: CNMF = self.get_output()
cnmf_obj.estimates.detrend_df_f(
quantileMin=quantileMin,
frames_window=frames_window,
flag_auto=flag_auto,
use_fast=use_fast,
use_residuals=use_residuals,
detrend_only=detrend_only
)

# remove current hdf5 file
cnmf_obj_path = self.get_output_path()
cnmf_obj_path.unlink()

# save new hdf5 file with new F_dff vals
cnmf_obj.save(str(cnmf_obj_path))

@validate("cnmf")
@cache.use_cache
def get_detrend_dfof(self):
cnmf_obj = self.get_output()
if cnmf_obj.estimates.F_dff is None:
raise AttributeError("You must run ``cnmf.run_detrend_dfof()`` first")

return cnmf_obj.estimates.F_dff

@validate("cnmf")
@_check_permissions
@cache.invalidate()
def run_eval(self, params: dict) -> None:
"""
Run component evaluation
Warnings
--------
Overwrites the existing cnmf hdf5 outfile file for this batch item
Parameters
----------
params: dict
dict of parameters for component evaluation
============== =================
parameter details
============== =================
SNR_lowest ``float``, minimum accepted SNR value
cnn_lowest ``float``, minimum accepted value for CNN classifier
gSig_range ``List[int, int]`` or ``None``, range for gSig scale for CNN classifier
min_SNR ``float``, transient SNR threshold
min_cnn_thr ``float``, threshold for CNN classifier
rval_lowest ``float``, minimum accepted space correlation
rval_thr ``float``, space correlation threshold
use_cnn ``bool``, use CNN based classifier
use_ecc ``bool``, flag for eccentricity based filtering
max_ecc ``float``, max eccentricity
============== =================
Returns
-------
None
Notes
------
invalidates the cache for this batch item.
"""

cnmf_obj = self.get_output()

valid = list(cnmf_obj.params.quality.keys())
for k in params.keys():
if k not in valid:
raise KeyError(
f"passed params dict key `{k}` is not a valid parameter for quality evaluation\n"
f"valid param keys are: {valid}"
)

cnmf_obj.params.quality.update(params)
cnmf_obj.estimates.filter_components(
imgs=self.get_input_memmap(),
params=cnmf_obj.params
)

cnmf_obj_path = self.get_output_path()
cnmf_obj_path.unlink()

cnmf_obj.save(str(cnmf_obj_path))
self._series["params"]["eval"] = deepcopy(params)

@validate("cnmf")
def get_good_components(self) -> np.ndarray:
"""
get the good component indices, ``CNMF.estimates.idx_components``
Returns
-------
np.ndarray
array of ints, indices of good components
"""

cnmf_obj = self.get_output()
return cnmf_obj.estimates.idx_components

@validate("cnmf")
def get_bad_components(self) -> np.ndarray:
"""
get the bad component indices, ``CNMF.estimates.idx_components_bad``
Returns
-------
np.ndarray
array of ints, indices of bad components
"""
cnmf_obj = self.get_output()
return cnmf_obj.estimates.idx_components_bad

0 comments on commit ce35f88

Please sign in to comment.