Skip to content

Commit

Permalink
fixed sklearn import req and some linting fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
abao1999 committed Sep 19, 2024
1 parent 37700ce commit aadae8c
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 25 deletions.
14 changes: 10 additions & 4 deletions dysts/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@
import warnings

import numpy as np
from scipy.spatial.distance import cdist
from scipy.stats import linregress
from scipy.spatial.distance import cdist # type: ignore
from scipy.stats import linregress # type: ignore

from .flows import DynSys
from .base import DynSys
from .utils import (
ComputationHolder,
find_significant_frequencies,
Expand All @@ -21,7 +21,11 @@
)

if has_module("sklearn"):
from sklearn.linear_model import RidgeCV
from sklearn.linear_model import RidgeCV # type: ignore
else:
warnings.warn(
"Sklearn not installed. Will not be able to use ridge regression for gpdistance and corr_gpdim."
)


def sample_initial_conditions(
Expand Down Expand Up @@ -226,6 +230,8 @@ def gpdistance(traj1, traj2, standardize=True, register=False, **kwargs):
"""

if register:
if not has_module("sklearn"):
raise ImportError("Sklearn is required for registration")
model = RidgeCV()
model.fit(traj1, traj2)
traj1 = model.predict(traj1)
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ dependencies = ["numpy", "scipy", "tqdm"]
benchmarks = ["darts", "torch", "scikit-learn", "pandas", "tsfresh", "nolds"]
numba = ["numba"]
tests = ["matplotlib"]
extra = ["sdeint", "sklearn"]
extra = ["sdeint", "scikit-learn"]

# Package data
[tool.setuptools]
Expand Down
28 changes: 8 additions & 20 deletions tests/test_analysis.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import matplotlib.pyplot as plt

import dysts.flows as dfl
from dysts.analysis import compute_timestep

if __name__ == "__main__":
num_points_per_period = 1024
Expand All @@ -10,26 +9,15 @@
dyst_name = "Lorenz"
system = getattr(dfl, dyst_name)()

traj = system.make_trajectory(
num_points,
resample=True,
dt = compute_timestep(
system,
total_length=num_points,
transient_fraction=0.2,
num_iters=5,
pts_per_period=num_points_per_period,
postprocess=False,
return_period=True,
)

print(traj.shape)
plt.plot(traj)
plt.show()

# dt = compute_timestep(
# system,
# total_length=num_points,
# transient_fraction=0.2,
# num_iters=5,
# pts_per_period=num_points_per_period,
# return_period=True,
# )
# print("result: ", dt)
print("result: ", dt)

# TODO: Use scipy optimize for black box optimization of dt from initial guess
# until it meets characteristic timescale criteria
Expand Down

0 comments on commit aadae8c

Please sign in to comment.