Skip to content

Commit

Permalink
generators & sim multi can take str function defs
Browse files Browse the repository at this point in the history
  • Loading branch information
TomDonoghue committed Sep 8, 2024
1 parent 6269603 commit f02869c
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 6 deletions.
3 changes: 1 addition & 2 deletions neurodsp/sim/combined.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,7 @@ def sim_combined(n_seconds, fs, components, component_variances=1):
raise ValueError('Signal components and variances lengths do not match.')

# Collect the sim function to use, and repeat variance if is single number
components = {(get_sim_func(name) if isinstance(name, str) else name) : params \
for name, params in components.items()}
components = {get_sim_func(name) : params for name, params in components.items()}
variances = repeat(component_variances) if \
isinstance(component_variances, (int, float, np.number)) else iter(component_variances)

Expand Down
10 changes: 8 additions & 2 deletions neurodsp/sim/generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from collections.abc import Sized

from neurodsp.sim.info import get_sim_func
from neurodsp.utils.core import counter

###################################################################################################
Expand All @@ -12,8 +13,9 @@ def sig_yielder(sim_func, params, n_sims):
Parameters
----------
sim_func : callable
sim_func : str or callable
Function to create the simulated time series.
If string, should be the name of the desired simulation function.
params : dict
The parameters for the simulated signal, passed into `sim_func`.
n_sims : int, optional
Expand All @@ -26,6 +28,7 @@ def sig_yielder(sim_func, params, n_sims):
Simulated time series.
"""

sim_func = get_sim_func(sim_func)
for _ in counter(n_sims):
yield sim_func(**params)

Expand All @@ -35,8 +38,9 @@ def sig_sampler(sim_func, params, return_params=False, n_sims=None):
Parameters
----------
sim_func : callable
sim_func : str or callable
Function to create the simulated time series.
If string, should be the name of the desired simulation function.
params : iterable
The parameters for the simulated signal, passed into `sim_func`.
return_params : bool, optional, default: False
Expand All @@ -54,6 +58,8 @@ def sig_sampler(sim_func, params, return_params=False, n_sims=None):
Only returned if `return_params` is True.
"""

sim_func = get_sim_func(sim_func)

# If `params` has a size, and `n_sims` is defined, check that they are compatible
# To do so, we first check if the iterable has a __len__ attr, and if so check values
if isinstance(params, Sized) and len(params) and n_sims and n_sims > len(params):
Expand Down
7 changes: 6 additions & 1 deletion neurodsp/sim/info.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,10 @@ def get_sim_func(function_name, modules=SIM_MODULES):
Parameters
----------
function_name : str
function_name : str or callabe
Name of the sim function to retrieve.
If callable, returns input.
If string searches for corresponding callable sim function.
modules : list of str, optional
Which sim modules to look for the function in.
Expand All @@ -69,6 +71,9 @@ def get_sim_func(function_name, modules=SIM_MODULES):
Requested sim function.
"""

if callable(function_name):
return function_name

for module in modules:
try:
func = get_sim_funcs(module)[function_name]
Expand Down
9 changes: 8 additions & 1 deletion neurodsp/sim/multi.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from neurodsp.sim.signals import Simulations, VariableSimulations, MultiSimulations
from neurodsp.sim.generators import sig_yielder, sig_sampler
from neurodsp.sim.params import get_base_params
from neurodsp.sim.info import get_sim_func

###################################################################################################
###################################################################################################
Expand All @@ -14,6 +15,7 @@ def sim_multiple(sim_func, params, n_sims):
----------
sim_func : callable
Function to create the simulated time series.
If string, should be the name of the desired simulation function.
params : dict
The parameters for the simulated signal, passed into `sim_func`.
n_sims : int
Expand Down Expand Up @@ -48,6 +50,7 @@ def sim_across_values(sim_func, params):
----------
sim_func : callable
Function to create the simulated time series.
If string, should be the name of the desired simulation function.
params : ParamIter or iterable or list of dict
Simulation parameters for `sim_func`.
Expand Down Expand Up @@ -78,6 +81,8 @@ def sim_across_values(sim_func, params):
update=getattr(params, 'update', None),
component=getattr(params, 'component', None))

sim_func = get_sim_func(sim_func)

for ind, cur_params in enumerate(params):
sims.add_signal(sim_func(**cur_params), cur_params, index=ind)

Expand All @@ -91,6 +96,7 @@ def sim_multi_across_values(sim_func, params, n_sims):
----------
sim_func : callable
Function to create the simulated time series.
If string, should be the name of the desired simulation function.
params : ParamIter or iterable or list of dict
Simulation parameters for `sim_func`.
n_sims : int
Expand Down Expand Up @@ -132,8 +138,9 @@ def sim_from_sampler(sim_func, sampler, n_sims):
Parameters
----------
sim_func : callable
sim_func : str callable
Function to create the simulated time series.
If string, should be the name of the desired simulation function.
sampler : ParamSampler
Parameter definition to sample from.
n_sims : int
Expand Down

0 comments on commit f02869c

Please sign in to comment.