Skip to content

Commit

Permalink
Merge pull request #31 from spapa013/main
Browse files Browse the repository at this point in the history
v0.1.10
  • Loading branch information
spapa013 authored Jan 9, 2024
2 parents 95f4c04 + 5e54622 commit 736c548
Show file tree
Hide file tree
Showing 5 changed files with 211 additions and 16 deletions.
26 changes: 24 additions & 2 deletions microns_utils/datajoint_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from datajoint_plus.utils import format_rows_to_df
from datajoint_plus import base as djpb
from datajoint_plus import user_tables as djpu
from .misc_utils import classproperty, wrap
from .misc_utils import classproperty, wrap, unwrap
from .version_utils import check_package_version
from .datetime_utils import current_timestamp

Expand Down Expand Up @@ -312,4 +312,26 @@ def make(self, key):
self.on_make(key)

def on_make(self, key):
    """
    Hook invoked by `make` after the row for `key` has been inserted.
    No-op by default; subclasses override this to run post-insert logic.

    :param key (dict): the restriction key passed to `make`.
    """
    pass


def run_method_from_parts(dj_table, key):
    """
    Calls dj_table.r1p(key) and then run(**key).

    :param dj_table: datajoint table exposing `r1p` (restrict-to-one-part).
    :param key (dict): key used both to restrict to a part and as kwargs to `run`.

    Returns whatever the resolved part's `run` method returns.
    """
    part = dj_table.r1p(key)
    return part.run(**key)


def get_from_parts(dj_table, key=None):
    """
    Restricts and calls part.get() for each part table.

    :param dj_table (dj.Table): datajoint table
    :param key (dict): key to restrict by. If None, returns all rows.

    Returns the result(s) as a single flattened list.
    """
    # Flatten the per-part results into one list, preserving part order.
    return [
        row
        for part_cls in dj_table.parts(as_cls=True)
        for row in part_cls().get(key=key)
    ]
145 changes: 141 additions & 4 deletions microns_utils/misc_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,25 @@ def __get__(self, obj, owner):
return self.f(owner)


def wrap(item, return_as_list=False):
    """
    Wrap an item in a container unless it is already a list or tuple.

    :param item: Any object. Lists and tuples are returned unchanged.
    :param return_as_list (bool): If True, wrap a non-list/tuple item in a list
        ([item]); otherwise wrap it in a tuple ((item,)). Defaults to False.

    Returns the item unchanged if it is a list or tuple, otherwise the wrapped item.
    """
    if isinstance(item, (list, tuple)):
        return item
    # NOTE: `list(item)` here would raise TypeError for non-iterables and would
    # split iterables like strings into their elements; wrap with a literal instead.
    return [item] if return_as_list else (item,)


def unwrap(item, enforce_one_item=False):
    """
    Unpack a length-1 list or tuple to the contained item.

    :param item: Any object. Non-list/tuple values are returned unchanged.
    :param enforce_one_item (bool): If True, raise ValueError when item is a
        list/tuple whose length is not exactly 1. Defaults to False.

    Returns item[0] for a length-1 list/tuple, otherwise item itself.
    """
    if not isinstance(item, (list, tuple)):
        return item
    if len(item) == 1:
        return item[0]
    if enforce_one_item:
        raise ValueError(f"Expected length 1, got length {len(item)}")
    return item

def sc_to_ucc(string):
    """
    Convert a snake_case str to UpperCamelCase.
    """
    words = string.title().split('_')
    return ''.join(words)


class FieldDict(dict):
"""
FieldDict is an enhanced dictionary that allows attribute-style access
to its keys, in addition to the standard dictionary-style access.
It also automatically converts nested dictionaries to FieldDict instances,
enabling recursive attribute-style access.
Example:
fd = FieldDict(a=1, b={'c': 2, 'd': {'e': 3}})
print(fd.a) # Outputs 1
print(fd.b.d.e) # Outputs 3
Attributes are accessed just like dictionary keys. If an attribute does not exist,
AttributeError is raised.
"""
def __init__(self, **kwargs):
"""
Initialize a FieldDict instance. All keyword arguments provided are set as keys
and values of the dictionary. Nested dictionaries are automatically converted
to FieldDict instances.
Args:
**kwargs: Arbitrary keyword arguments. Each keyword becomes a key in the
FieldDict, and the corresponding value becomes the value.
Optional keyword arguments:
_name (str, optional): The name of the FieldDict instance. Defaults to "FieldDict".
_key_disp_limit (int, optional): The maximum number of keys to display when a FieldDict object is accessed.
Key display can be disabled by setting to 0 or None. Defaults to 4.
"""
for param, default_value in self._defaults.items():
setattr(self, param, kwargs.pop(param, default_value))

super().__init__()
for key, value in kwargs.items():
self[key] = self._convert(value)

_defaults = {'_name': "FieldDict", '_key_disp_limit': 4}

def __setitem__(self, key, value):
super(FieldDict, self).__setitem__(key, self._convert(value))

def __getattr__(self, name):
try:
return self[name]
except KeyError:
raise AttributeError(f"'FieldDict' object has no attribute '{name}'")

def __setattr__(self, name, value):
if name.startswith('_'):
# Handle setting of private and protected attributes normally
super(FieldDict, self).__setattr__(name, value)
else:
self[name] = self._convert(value)

def __delattr__(self, name):
if name.startswith('_'):
# Handle deletion of private and protected attributes normally
super(FieldDict, self).__delattr__(name)
else:
del self[name]

def __repr__(self):
keys = list(self.keys())
len_keys = len(keys)
key_disp_limit = self._key_disp_limit or 0
name = self._name or "FieldDict"
prefix = f"<{name} object at {hex(id(self))}"

if len_keys == 1:
key_len_repr = f" with 1 key"
else:
key_len_repr = f" with {len_keys} keys"

if len_keys == 0 or key_disp_limit == 0:
return prefix + key_len_repr + ">"

if key_disp_limit == 1:
if len_keys == 1:
key_disp_repr = f": '{keys[0]}'>"
else:
key_disp_repr = f": '{keys[0]}', ... >"
return prefix + key_len_repr + key_disp_repr

if len_keys > key_disp_limit:
key_disp_repr = ", ".join(f"'{k}'" for k in keys[:key_disp_limit//2]) + ", ..., " + ", ".join(f"'{k}'" for k in keys[-key_disp_limit//2:])
else:
key_disp_repr = ", ".join(f"'{k}'" for k in keys)
key_disp_repr = f": {key_disp_repr}>"
return prefix + key_len_repr + key_disp_repr

def get_with_path(self, path, default=None):
"""
Retrieve a value from the FieldDict using a dot-separated path. If the path
does not exist, the method returns the specified default value.
Example:
fd = FieldDict(a=1, b={'c': 2, 'd': {'e': 3}})
value = fd.get_with_path('b.d.e') # Returns 3
Args:
path (str): A dot-separated path string indicating the nested keys.
For example, 'a.b.c' refers to the path dict['a']['b']['c'].
default (any, optional): The default value to return if the path is not found.
Defaults to None.
Returns:
The value found at the path, or the default value if the path is not found.
"""
keys = path.split('.')
current = self
for key in keys:
if key in current:
current = current[key]
else:
return default
return current

gwp = get_with_path # alias for get_with_path

@staticmethod
def _convert(value):
if isinstance(value, dict) and not isinstance(value, FieldDict):
return FieldDict(**value)
elif isinstance(value, (list, set, tuple)):
return type(value)(FieldDict._convert(v) for v in value)
return value
32 changes: 32 additions & 0 deletions microns_utils/transform_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import numpy as np
from scipy.stats import gaussian_kde
from scipy import ndimage
from .misc_utils import wrap


Expand Down Expand Up @@ -34,6 +35,37 @@ def normalize_image(image, newrange=[0, 255], clip_bounds=None, astype=np.uint8)
return (((image - image.min())*(newrange[1]-newrange[0])/(image.max() - image.min())) + newrange[0]).astype(astype)


def lcn(image, sigmas=(12, 12)):
    """ Local contrast normalization.

    Normalize each pixel using mean and stddev computed on a local neighborhood.

    We use gaussian filters rather than uniform filters to compute the local mean and std
    to soften the effect of edges. Essentially we are using a fuzzy local neighborhood.
    Equivalent using a hard definition of neighborhood will be:
        local_mean = ndimage.uniform_filter(image, size=(32, 32))

    :param np.array image: Array with raw two-photon images.
    :param tuple sigmas: List with sigmas (one per axis) to use for the gaussian filter.
        Smaller values result in more local neighborhoods. 15-30 microns should work fine
    """
    mean_map = ndimage.gaussian_filter(image, sigmas)
    # Var = E[x^2] - E[x]^2; clip tiny negatives caused by floating-point error.
    var_map = ndimage.gaussian_filter(image ** 2, sigmas) - mean_map ** 2
    std_map = np.sqrt(np.clip(var_map, a_min=0, a_max=None))
    # Small epsilon avoids division by zero in perfectly flat regions.
    return (image - mean_map) / (std_map + 1e-7)


def sharpen_2pimage(image, laplace_sigma=0.7, low_percentile=3, high_percentile=99.9):
    """ Apply a laplacian filter, clip pixel range and normalize.

    :param np.array image: Array with raw two-photon images.
    :param float laplace_sigma: Sigma of the gaussian used in the laplace filter.
    :param float low_percentile, high_percentile: Percentiles at which to clip.

    :returns: Array of same shape as input. Sharpened image.
    """
    # Subtracting the Laplacian-of-Gaussian emphasizes edges (unsharp-mask style).
    sharpened = image - ndimage.gaussian_laplace(image, laplace_sigma)
    lo, hi = np.percentile(sharpened, [low_percentile, high_percentile])
    clipped = np.clip(sharpened, lo, hi)
    # Center on the mean and scale by the (epsilon-padded) dynamic range.
    return (clipped - clipped.mean()) / (clipped.max() - clipped.min() + 1e-7)


def run_kde(data, nbins, bounds='auto', method='gaussian_kde', method_kws=None):
"""
Generate kernel density estimation from data
Expand Down
2 changes: 1 addition & 1 deletion microns_utils/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '0.1.9'
__version__ = '0.1.10'
22 changes: 13 additions & 9 deletions microns_utils/version_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import traceback
import requests
from .misc_utils import wrap
try:
from importlib import metadata
except ImportError:
Expand All @@ -15,6 +16,7 @@
from pathlib import Path
from .filepath_utils import find_all_matching_files

logger = logging.getLogger(__name__)

def parse_version(text: str):
"""
Expand Down Expand Up @@ -60,22 +62,22 @@ def check_latest_version_from_github(owner, repo, source, branch='main', path_to
elif source == 'tag':
f = requests.get(f"https://api.github.com/repos/{owner}/{repo}/tags")
if not f.ok:
logging.error(f'Could not check Github version because: "{f.reason}".')
logger.error(f'Could not check Github version because: "{f.reason}".')
return latest
latest = parse_version(json.loads(f.text)[0]['name'][1:])

elif source == 'release':
f = requests.get(f"https://api.github.com/repos/{owner}/{repo}/releases")
if not f.ok:
logging.error(f'Could not check Github version because: "{f.reason}".')
logger.error(f'Could not check Github version because: "{f.reason}".')
return latest
latest = parse_version(json.loads(f.text)[0]['tag_name'][1:])

else:
raise ValueError(f'source: "{source}" not recognized. Options include: "commit", "tag", "release". ')
except:
if warn:
logging.warning('Failed to check latest version from Github.')
logger.warning('Failed to check latest version from Github.')
traceback.print_exc()

return latest
Expand Down Expand Up @@ -115,7 +117,7 @@ def check_package_version_from_distributions(package, warn=True):
version = [dist.version for dist in metadata.distributions() if dist.metadata["Name"] == package]
if not version:
if warn:
logging.warning('Package not found in distributions.')
logger.warning('Package not found in distributions.')
return ''
return version[0]

Expand All @@ -136,21 +138,21 @@ def check_package_version_from_sys_path(package, path_to_version_file, prefix=''

if len(paths)>1:
if warn:
logging.warning(err_base_str + f'{len(paths)} paths containing {package} were found in sys.path. Consider adding a prefix for further specification.')
logger.warning(err_base_str + f'{len(paths)} paths containing {package} were found in sys.path. Consider adding a prefix for further specification.')
[print(p) for p in paths]
return ''

elif len(paths) == 0:
if warn:
logging.warning(err_base_str + f'no paths matching {package} were found in sys.path.')
logger.warning(err_base_str + f'no paths matching {package} were found in sys.path.')
return ''

else:
files = find_all_matching_files('version.py', paths[0])

if len(files) == 0:
if warn:
logging.warning(err_base_str + 'no version.py file was found.')
logger.warning(err_base_str + 'no version.py file was found.')
return ''

else:
Expand Down Expand Up @@ -190,6 +192,8 @@ def check_package_version(package, prefix='', check_if_latest=False, check_if_la

if __version__ != latest:
if warn:
logging.warning(f'You are using {package} version {__version__}, which does not match the latest version on Github, {latest}.')
logger.warning(f'You are using {package} version {__version__}, which does not match the latest version on Github, {latest}.')

return __version__
return __version__


0 comments on commit 736c548

Please sign in to comment.