Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

dfof, eval, other final things #87

Merged
merged 2 commits into from
Aug 1, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions mesmerize_core/caiman_extensions/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,3 +186,40 @@ def _use_cache(instance, *args, **kwargs):
return _return_wrapper(return_val, copy_bool=return_copy)

return _use_cache

def invalidate(self, pre: bool = True, post: bool = True):
"""
invalidate all cache entries associated to a single batch item

Parameters
----------
pre: bool
invalidate before the decorated function has been fun

post: bool
invalidate after the decorated function has been fun

"""
def _invalidate(func):
@wraps(func)
def __invalidate(instance, *args, **kwargs):
u = instance._series["uuid"]

if pre:
self.cache.drop(
self.cache.loc[self.cache["uuid"] == u].index,
inplace=True
)

rval = func(instance, *args, **kwargs)

if post:
self.cache.drop(
self.cache.loc[self.cache["uuid"] == u].index,
inplace=True
)

return rval

return __invalidate
return _invalidate
189 changes: 187 additions & 2 deletions mesmerize_core/caiman_extensions/cnmf.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from caiman.source_extraction.cnmf.cnmf import load_CNMF
from caiman.utils.visualization import get_contours as caiman_get_contours
from functools import wraps
import os
from copy import deepcopy

from .common import validate
from .cache import Cache
Expand Down Expand Up @@ -60,6 +62,20 @@ def _parser(instance, *args, **kwargs) -> Any:
return _parser


def _check_permissions(func):
@wraps(func)
def __check(instance, *args, **kwargs):
cnmf_obj_path = instance.get_output_path()

if not os.access(cnmf_obj_path, os.W_OK):
raise PermissionError(
"You do not have write access to the hdf5 output file for this batch item"
)

return func(instance, *args, **kwargs)
return __check


@pd.api.extensions.register_series_accessor("cnmf")
class CNMFExtensions:
"""
Expand Down Expand Up @@ -145,7 +161,7 @@ def get_masks(
) -> np.ndarray:
"""
| Get binary masks of the spatial components at the given ``component_indices``.
| Created from cnmf.estimates.A.
| Created from ``CNMF.estimates.A``

Parameters
----------
Expand Down Expand Up @@ -262,7 +278,7 @@ def get_temporal(
self, component_indices: Union[np.ndarray, str] = None, add_background: bool = False, return_copy=True
) -> np.ndarray:
"""
Get the temporal components for this CNMF item, basically ``cnmf.estimates.C``
Get the temporal components for this CNMF item, basically ``CNMF.estimates.C``

Parameters
----------
Expand Down Expand Up @@ -430,3 +446,172 @@ def get_residuals(
residuals = raw_movie[np.arange(*frame_indices)] - reconstructed_movie - background

return residuals.reshape(cnmf_obj.dims + (-1,), order="F").transpose([2, 0, 1])

@validate("cnmf")
@_check_permissions
@cache.invalidate()
def run_detrend_dfof(
self,
quantileMin: float = 8,
frames_window: int = 500,
flag_auto: bool = True,
use_fast: bool = False,
use_residuals: bool = True,
detrend_only: bool = False
) -> None:
"""
| Uses caiman's detrend_df_f.
| call ``cnmf.get_detrend_dfof()`` to get the values.
| Sets ``CNMF.estimates.F_dff``

Warnings
--------
Overwrites the existing cnmf hdf5 outfile file for this batch item

Parameters
----------
quantile_min: float
quantile used to estimate the baseline (values in [0,100])
used only if 'flag_auto' is False, i.e. ignored by default

frames_window: int
number of frames for computing running quantile

flag_auto: bool
flag for determining quantile automatically

use_fast: bool
flag for using approximate fast percentile filtering

detrend_only: bool
flag for only subtracting baseline and not normalizing by it.
Used in 1p data processing where baseline fluorescence cannot be
determined.

Returns
-------
None

Notes
------
invalidates the cache for this batch item.

"""

cnmf_obj: CNMF = self.get_output()
cnmf_obj.estimates.detrend_df_f(
quantileMin=quantileMin,
frames_window=frames_window,
flag_auto=flag_auto,
use_fast=use_fast,
use_residuals=use_residuals,
detrend_only=detrend_only
)

# remove current hdf5 file
cnmf_obj_path = self.get_output_path()
cnmf_obj_path.unlink()

# save new hdf5 file with new F_dff vals
cnmf_obj.save(str(cnmf_obj_path))

@validate("cnmf")
@cache.use_cache
def get_detrend_dfof(self):
cnmf_obj = self.get_output()
if cnmf_obj.estimates.F_dff is None:
raise AttributeError("You must run ``cnmf.run_detrend_dfof()`` first")

return cnmf_obj.estimates.F_dff

@validate("cnmf")
@_check_permissions
@cache.invalidate()
def run_eval(self, params: dict) -> None:
"""
Run component evaluation

Warnings
--------
Overwrites the existing cnmf hdf5 outfile file for this batch item

Parameters
----------
params: dict
dict of parameters for component evaluation

============== =================
parameter details
============== =================
SNR_lowest ``float``, minimum accepted SNR value
cnn_lowest ``float``, minimum accepted value for CNN classifier
gSig_range ``List[int, int]`` or ``None``, range for gSig scale for CNN classifier
min_SNR ``float``, transient SNR threshold
min_cnn_thr ``float``, threshold for CNN classifier
rval_lowest ``float``, minimum accepted space correlation
rval_thr ``float``, space correlation threshold
use_cnn ``bool``, use CNN based classifier
use_ecc ``bool``, flag for eccentricity based filtering
max_ecc ``float``, max eccentricity
============== =================

Returns
-------
None

Notes
------
invalidates the cache for this batch item.

"""

cnmf_obj = self.get_output()

valid = list(cnmf_obj.params.quality.keys())
for k in params.keys():
if k not in valid:
raise KeyError(
f"passed params dict key `{k}` is not a valid parameter for quality evaluation\n"
f"valid param keys are: {valid}"
)

cnmf_obj.params.quality.update(params)
cnmf_obj.estimates.filter_components(
imgs=self.get_input_memmap(),
params=cnmf_obj.params
)

cnmf_obj_path = self.get_output_path()
cnmf_obj_path.unlink()

cnmf_obj.save(str(cnmf_obj_path))
self._series["params"]["eval"] = deepcopy(params)

@validate("cnmf")
def get_good_components(self) -> np.ndarray:
"""
get the good component indices, ``CNMF.estimates.idx_components``

Returns
-------
np.ndarray
array of ints, indices of good components

"""

cnmf_obj = self.get_output()
return cnmf_obj.estimates.idx_components

@validate("cnmf")
def get_bad_components(self) -> np.ndarray:
"""
get the bad component indices, ``CNMF.estimates.idx_components_bad``

Returns
-------
np.ndarray
array of ints, indices of bad components

"""
cnmf_obj = self.get_output()
return cnmf_obj.estimates.idx_components_bad