Skip to content

Commit

Permalink
revamp tests; slight tweak to tranform_params function
Browse files Browse the repository at this point in the history
  • Loading branch information
jbial committed Dec 9, 2024
1 parent 216337b commit 9a65c9c
Show file tree
Hide file tree
Showing 4 changed files with 130 additions and 163 deletions.
17 changes: 9 additions & 8 deletions dysts/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import warnings
from functools import partial
from importlib import resources
from itertools import starmap
from typing import Any, Callable, Sequence

import numpy as np
Expand Down Expand Up @@ -190,17 +189,19 @@ def transform_params(
self, transform_fn: Callable[[str, np.ndarray, Any], np.ndarray | None]
) -> bool:
"""Updates the current parameter list via a transform function"""
transformed_params = list(
starmap(
partial(transform_fn, system=self), # type: ignore
zip(sorted(self.params.keys()), self.param_list),
transformed_params = {
param_name: transform_fn(param_name, param_value, system=self) # type: ignore
for param_name, param_value in zip(
sorted(self.params.keys()), self.param_list
)
)
}

if any(p is None for p in transformed_params):
if any(p is None for p in transformed_params.values()):
return False

self.param_list = transformed_params
self.params = transformed_params
self.__dict__.update(self.params)
self.param_list = [self.params[key] for key in sorted(self.params.keys())]
return True

def make_trajectory(self, *args, **kwargs):
Expand Down
68 changes: 60 additions & 8 deletions tests/test_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

import numpy as np

import dysts.flows as dfl
from dysts.sampling import GaussianInitialConditionSampler, GaussianParamSampler
from dysts.systems import get_attractor_list, make_trajectory_ensemble


Expand Down Expand Up @@ -72,16 +74,66 @@ def test_trajectories(self):
self.assertEqual(len(trajs), num_trials)

def test_ensemble_generation_initial_condition_sampling(self):
"""
TODO: add tests for initial condition sampling
"""
pass
ic_sampler = GaussianInitialConditionSampler(
scale=1e-4, random_seed=random.randint(0, 1000000)
)
system_sample = random.sample(get_attractor_list(sys_class="continuous"), 4)
systems = [getattr(dfl, sys)() for sys in system_sample]
unperturbed_sols = make_trajectory_ensemble(
256,
pts_per_period=64,
use_multiprocessing=True,
subset=systems,
)

for sys in systems:
sys.transform_ic(ic_sampler)

perturbed_sols = make_trajectory_ensemble(
256,
pts_per_period=64,
use_multiprocessing=True,
subset=systems,
)

for system_name in system_sample:
unperturbed_traj = unperturbed_sols[system_name]
perturbed_traj = perturbed_sols[system_name]
self.assertTrue(unperturbed_traj is not None)
self.assertTrue(perturbed_traj is not None)
self.assertEqual(unperturbed_traj.shape, perturbed_traj.shape)
self.assertFalse(np.allclose(unperturbed_traj, perturbed_traj))

def test_ensemble_generation_parameter_sampling(self):
"""
TODO: add tests for parameter sampling
"""
pass
param_sampler = GaussianParamSampler(
scale=1e-4, random_seed=random.randint(0, 1000000)
)
system_sample = random.sample(get_attractor_list(sys_class="continuous"), 4)
systems = [getattr(dfl, sys)() for sys in system_sample]
unperturbed_sols = make_trajectory_ensemble(
256,
pts_per_period=64,
use_multiprocessing=True,
subset=systems,
)

for sys in systems:
sys.transform_params(param_sampler)

perturbed_sols = make_trajectory_ensemble(
256,
pts_per_period=64,
use_multiprocessing=True,
subset=systems,
)

for system_name in system_sample:
unperturbed_traj = unperturbed_sols[system_name]
perturbed_traj = perturbed_sols[system_name]
self.assertTrue(unperturbed_traj is not None)
self.assertTrue(perturbed_traj is not None)
self.assertEqual(unperturbed_traj.shape, perturbed_traj.shape)
self.assertFalse(np.allclose(unperturbed_traj, perturbed_traj))


if __name__ == "__main__":
Expand Down
135 changes: 61 additions & 74 deletions tests/test_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,25 +5,30 @@

#!/usr/bin/env python
import os
import random
import sys
import unittest

import numpy as np

import dysts.flows as dfl
import dysts.maps as dmp
from dysts.base import DynMap, DynSys, DynSysDelay
from dysts.flows import Lorenz
from dysts.systems import get_attractor_list, make_trajectory_ensemble
from dysts.systems import get_attractor_list

WORKING_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
DATA_PATH = os.path.join(WORKING_DIR, "tests", "test_data")
print(WORKING_DIR)

sys.path.insert(1, os.path.join(WORKING_DIR, "dysts"))

NUM_TEST_SYSTEMS = 10


class TestModels(unittest.TestCase):
"""
Tests integration and models
Tests integration
"""

def test_trajectory(self):
Expand All @@ -50,79 +55,61 @@ def test_trajectory_noise(self):
assert sol is not None, "Generated trajectory is None"
assert sol.shape == (100, 3), "Generated time series has the wrong shape" # type: ignore

## Test removed due to the need to re-generate the reference data every time
## a new system is added to the database
# def test_ensemble(self):
# """
# Test all systems in the database
# """
# all_trajectories = make_trajectory_ensemble(5, method="Radau", resample=True)
# assert len(all_trajectories.keys()) >= 131

# xvals = np.array([all_trajectories[key][:, 0] for key in all_trajectories.keys()])
# xvals_reference = np.load(os.path.join(DATA_PATH, "all_trajectories.npy"), allow_pickle=True)
# diff_names = np.array(list(all_trajectories.keys()))[np.sum(np.abs(xvals - xvals_reference), axis=1) > 0]
# assert np.allclose(xvals, xvals_reference), "Generated trajectories do not match reference values for system {}".format(diff_names)

# # TODO: make sure a data file exists in the data folder that is referenced
# def test_precomputed(self):
# """
# Test loading a precomputed time series for a single system
# """
# dyst_name = "Lorenz"
# eq = getattr(dfl, dyst_name)()
# dyst_data_path = os.path.join(DATA_PATH, f"{dyst_name}.json.gz")
# if not os.path.exists(dyst_data_path):
# raise FileNotFoundError(f"File {dyst_data_path} does not exist")

# tpts, sol = eq.load_trajectory(
# data_path=DATA_PATH,
# standardize=True,
# return_times=True,
# )
# assert sol.shape == (1200, 3), "Generated time series has the wrong shape"
# assert tpts.shape == (1200,), "Time indices have the wrong shape"


class TestMakeTrajectoryEnsemble(unittest.TestCase):
def test_ensemble(self):
# Test that the function returns a dictionary with the correct keys
n = 100
subset = ["Lorenz", "Rossler"]
kwargs = {"method": "Radau"}
ensemble = make_trajectory_ensemble(
n,
subset=subset,
**kwargs, # type: ignore
def test_random_continuous_systems(self):
continuous_systems = get_attractor_list(sys_class="continuous_no_delay")
random_systems = random.sample(
continuous_systems, min(NUM_TEST_SYSTEMS, len(continuous_systems))
)

for system_name in random_systems:
with self.subTest(system=system_name):
print(f"Testing {system_name}")
system = getattr(dfl, system_name)()
self.assertIsInstance(system, DynSys)

sol = system.make_trajectory(256, return_times=True)
self.assertIsInstance(sol, tuple)
self.assertEqual(len(sol), 2)
self.assertIsInstance(sol[0], np.ndarray)
self.assertIsInstance(sol[1], np.ndarray)
self.assertEqual(sol[0].shape[0], 256)
self.assertEqual(sol[1].shape[0], 256)

def test_random_delay_systems(self):
delay_systems = get_attractor_list(sys_class="delay")
random_systems = random.sample(
delay_systems, min(NUM_TEST_SYSTEMS, len(delay_systems))
)
self.assertIsInstance(ensemble, dict)
self.assertEqual(set(ensemble.keys()), set(subset))

# Test that the function returns the correct number of timepoints
for key in ensemble:
self.assertEqual(ensemble[key].shape[0], n)

# Test that the function returns the correct shape of the solution array
for key in ensemble:
self.assertEqual(ensemble[key].shape[1], len(getattr(dfl, key)().ic))

# Test that the function returns the correct shape of the solution array
for key in ensemble:
self.assertEqual(ensemble[key].shape[0], n)

def test_multiprocessing(self):
# Test that the function returns a warning when multiprocessing is set to True
n = 100
subset = ["Lorenz", "Rossler"]
kwargs = {"method": "Radau"}
with self.assertRaises(Exception):
with self.assertWarns(UserWarning):
make_trajectory_ensemble(
n,
subset=subset,
use_multiprocessing=True,
**kwargs, # type: ignore
)

for system_name in random_systems:
with self.subTest(system=system_name):
print(f"Testing {system_name}")
system = getattr(dfl, system_name)()
self.assertIsInstance(system, DynSysDelay)

sol = system.make_trajectory(256, return_times=True)
self.assertIsInstance(sol, tuple)
self.assertEqual(len(sol), 2)
self.assertIsInstance(sol[0], np.ndarray)
self.assertIsInstance(sol[1], np.ndarray)
self.assertEqual(sol[0].shape[0], 256)
self.assertEqual(sol[1].shape[0], 256)

def test_random_discrete_maps(self):
discrete_maps = get_attractor_list(sys_class="discrete")
random_systems = random.sample(
discrete_maps, min(NUM_TEST_SYSTEMS, len(discrete_maps))
)

for system_name in random_systems:
with self.subTest(system=system_name):
print(f"Testing {system_name}")
system = getattr(dmp, system_name)()
self.assertIsInstance(system, DynMap)

sol = system.make_trajectory(256)
self.assertIsInstance(sol, np.ndarray)
self.assertEqual(sol.shape[0], 256)


class TestJacobian(unittest.TestCase):
Expand Down
73 changes: 0 additions & 73 deletions tests/test_integration.py

This file was deleted.

0 comments on commit 9a65c9c

Please sign in to comment.