Skip to content

Commit

Permalink
ENH: Offload multiple_trajectories guard code to helper func.
Browse files Browse the repository at this point in the history
  • Loading branch information
Jacob-Stevens-Haas committed Jul 11, 2023
1 parent 415a63b commit a279e11
Showing 1 changed file with 63 additions and 59 deletions.
122 changes: 63 additions & 59 deletions pysindy/pysindy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from itertools import product
from typing import Collection
from typing import Sequence
from typing import Union

import numpy as np
from scipy.integrate import odeint
Expand Down Expand Up @@ -180,7 +181,6 @@ def fit(
t=None,
x_dot=None,
u=None,
multiple_trajectories=False,
unbias=True,
):
"""
Expand Down Expand Up @@ -221,12 +221,6 @@ def fit(
for each trajectory. Individual trajectories may contain different
numbers of samples.
multiple_trajectories: boolean, optional, (default False)
Whether or not the training data includes multiple trajectories. If
True, the training data must be a list of arrays containing data
for each trajectory. If False, the training data must be a single
array.
unbias: boolean, optional (default True)
Whether to perform an extra step of unregularized linear regression to
unbias the coefficients for the identified support.
Expand All @@ -246,18 +240,8 @@ def fit(
if t is None:
t = self.t_default

if not multiple_trajectories:
if not _check_multiple_trajectories(x, x_dot, u):
x, t, x_dot, u = _adapt_to_multiple_trajectories(x, t, x_dot, u)
multiple_trajectories = True
elif (
not isinstance(x, Sequence)
or (not isinstance(x_dot, Sequence) and x_dot is not None)
or (not isinstance(u, Sequence) and u is not None)
):
raise TypeError(
"If multiple trajectories set, x and if included,"
"x_dot and u, must be Sequences"
)
x, x_dot, u = _comprehend_and_validate_inputs(
x, t, x_dot, u, self.feature_library
)
Expand All @@ -271,7 +255,7 @@ def fit(
trim_last_point=(self.discrete_time and x_dot is None),
)
self.n_control_features_ = u[0].shape[u[0].ax_coord]
x, x_dot = self._process_multiple_trajectories(x, t, x_dot)
x, x_dot = self._process_trajectories(x, t, x_dot)

# Append control variables
if u is not None:
Expand Down Expand Up @@ -304,7 +288,7 @@ def fit(

return self

def predict(self, x, u=None, multiple_trajectories=False):
def predict(self, x, u=None):
"""
Predict the time derivatives using the SINDy model.
Expand All @@ -319,17 +303,17 @@ def predict(self, x, u=None, multiple_trajectories=False):
must be a list of control variable data from each trajectory. If the
model was fit with control variables then u is not optional.
multiple_trajectories: boolean, optional (default False)
If True, x contains multiple trajectories and must be a list of
data from each trajectory. If False, x is a single trajectory.
Returns
-------
x_dot: array-like or list of array-like, shape (n_samples, n_input_features)
Predicted time derivatives
"""
if not multiple_trajectories:
if not _check_multiple_trajectories(x, None, u):
x, _, _, u = _adapt_to_multiple_trajectories(x, None, None, u)
multiple_trajectories = False
else:
multiple_trajectories = True

x, _, u = _comprehend_and_validate_inputs(x, 1, None, u, self.feature_library)

check_is_fitted(self, "model")
Expand Down Expand Up @@ -415,16 +399,7 @@ def print(self, lhs=None, precision=3):
else:
print(lhs[i] + " = " + eqn)

def score(
self,
x,
t=None,
x_dot=None,
u=None,
multiple_trajectories=False,
metric=r2_score,
**metric_kws
):
def score(self, x, t=None, x_dot=None, u=None, metric=r2_score, **metric_kws):
"""
Returns a score for the time derivative prediction produced by the model.
Expand Down Expand Up @@ -453,10 +428,6 @@ def score(
must be a list of control variable data from each trajectory.
If the model was fit with control variables then u is not optional.
multiple_trajectories: boolean, optional (default False)
If True, x contains multiple trajectories and must be a list of
data from each trajectory. If False, x is a single trajectory.
metric: callable, optional
Metric function with which to score the prediction. Default is the
R^2 coefficient of determination.
Expand All @@ -476,27 +447,26 @@ def score(
if t is None:
t = self.t_default

if not multiple_trajectories:
if not _check_multiple_trajectories(x, x_dot, u):
x, t, x_dot, u = _adapt_to_multiple_trajectories(x, t, x_dot, u)
multiple_trajectories = True
x, x_dot, u = _comprehend_and_validate_inputs(
x, t, x_dot, u, self.feature_library
)

x_dot_predict = self.predict(x, u, multiple_trajectories=multiple_trajectories)
x_dot_predict = self.predict(x, u)

if self.discrete_time and x_dot is None:
x_dot_predict = [xd[:-1] for xd in x_dot_predict]

x, x_dot = self._process_multiple_trajectories(x, t, x_dot)
x, x_dot = self._process_trajectories(x, t, x_dot)

x_dot = concat_sample_axis(x_dot)
x_dot_predict = concat_sample_axis(x_dot_predict)

x_dot, x_dot_predict = drop_nan_samples(x_dot, x_dot_predict)
return metric(x_dot, x_dot_predict, **metric_kws)

def _process_multiple_trajectories(self, x, t, x_dot):
def _process_trajectories(self, x, t, x_dot):
"""
Calculate derivatives of input data, iterating through trajectories.
Expand Down Expand Up @@ -544,7 +514,7 @@ def _process_multiple_trajectories(self, x, t, x_dot):
)
return x, x_dot

def differentiate(self, x, t=None, multiple_trajectories=False):
def differentiate(self, x, t=None):
"""
Apply the model's differentiation method
(:code:`self.differentiation_method`) to data.
Expand All @@ -559,10 +529,6 @@ def differentiate(self, x, t=None, multiple_trajectories=False):
Time step between samples or array of collection times.
If None, the default time step ``t_default`` will be used.
multiple_trajectories: boolean, optional (default False)
If True, x contains multiple trajectories and must be a list of
data from each trajectory. If False, x is a single trajectory.
Returns
-------
x_dot: array-like or list of array-like, shape (n_samples, n_input_features)
Expand All @@ -573,12 +539,15 @@ def differentiate(self, x, t=None, multiple_trajectories=False):
t = self.t_default
if self.discrete_time:
raise RuntimeError("No differentiation implemented for discrete time model")
if not multiple_trajectories:
if not _check_multiple_trajectories(x, None, None):
x, t, _, _ = _adapt_to_multiple_trajectories(x, t, None, None)
multiple_trajectories = False
else:
multiple_trajectories = True
x, _, _ = _comprehend_and_validate_inputs(
x, t, None, None, self.feature_library
)
result = self._process_multiple_trajectories(x, t, None)[1]
result = self._process_trajectories(x, t, None)[1]
if not multiple_trajectories:
return result[0]
return result
Expand Down Expand Up @@ -788,23 +757,58 @@ def _zip_like_sequence(x, t):
return product(x, [t])


def _adapt_to_multiple_trajectories(x, t, x_dot, u):
"""Adapt model data not already in multiple_trajectories to that format.
def _check_multiple_trajectories(x, x_dot, u) -> bool:
"""Determine if data contains multiple trajectories
Arguments:
Args:
x: Samples from which to make predictions.
t: Time step between samples or array of collection times.
x_dot: Pre-computed derivatives of the samples.
u: Control variables
Returns:
Tuple of updated x, t, x_dot, u
whether data has multiple trajectories
Raises:
TypeError if data contains a mix of single/multiple trajectories
ValueError if either data different numbers of trajectories
"""
SequenceOrNone = Union[Sequence, None]
mixed_trajectories = (
isinstance(x, Sequence)
and (not isinstance(x_dot, SequenceOrNone) or not isinstance(u, SequenceOrNone))
or isinstance(x_dot, Sequence)
and not isinstance(x, Sequence)
or isinstance(u, Sequence)
and not isinstance(x, Sequence)
)
if mixed_trajectories:
raise TypeError(
"If x, x_dot, or u are a Sequence of trajectories, each must be a Sequence"
" of trajectories or None."
)
if isinstance(x, Sequence):
raise ValueError(
"x is a Sequence, but multiple_trajectories not set. "
"Did you mean to set multiple trajectories?"
matching_lengths = (x_dot is None or len(x) == len(x_dot)) and (
u is None or len(x) == len(u)
)
if not matching_lengths:
raise ValueError("x, x_dot and/or u have mismatched number of trajectories")
return True
return False


def _adapt_to_multiple_trajectories(x, t, x_dot, u) -> tuple:
"""Adapt model data to that multiple_trajectories.
Args:
x: Samples from which to make predictions.
t: Time step between samples or array of collection times.
x_dot: Pre-computed derivatives of the samples.
u: Control variables
Returns:
Tuple of updated x, t, x_dot, u
"""
x = [x]
if isinstance(t, Collection):
t = [t]
Expand Down

0 comments on commit a279e11

Please sign in to comment.