diff --git a/pypots/imputation/lerp/__init__.py b/pypots/imputation/lerp/__init__.py index 0ca166fc..2d5c1155 100644 --- a/pypots/imputation/lerp/__init__.py +++ b/pypots/imputation/lerp/__init__.py @@ -9,4 +9,4 @@ __all__ = [ "Lerp", -] \ No newline at end of file +] diff --git a/pypots/imputation/lerp/model.py b/pypots/imputation/lerp/model.py index 2f39dcdf..ffdd60db 100644 --- a/pypots/imputation/lerp/model.py +++ b/pypots/imputation/lerp/model.py @@ -19,10 +19,11 @@ class Lerp(BaseImputer): """Linear interpolation (Lerp) imputation method. Lerp will linearly interpolate missing values between the nearest non-missing values. - If there are missing values at the beginning or end of the series, they will be back-filled or forward-filled with the nearest non-missing value, respectively. + If there are missing values at the beginning or end of the series, they will be back-filled or + forward-filled with the nearest non-missing value, respectively. If an entire series is empty, all 'nan' values will be filled with zeros. """ - + def __init__( self, ): @@ -95,14 +96,14 @@ def _interpolate_missing_values(X: np.ndarray): X[nans] = np.interp(nan_index, index, X[~nans]) elif np.any(nans): X[nans] = 0 - + if isinstance(X, np.ndarray): trans_X = X.transpose((0, 2, 1)) n_samples, n_features, n_steps = trans_X.shape reshaped_X = np.reshape(trans_X, (-1, n_steps)) imputed_X = np.ones(reshaped_X.shape) - + for i, univariate_series in enumerate(reshaped_X): t = np.copy(univariate_series) _interpolate_missing_values(t) @@ -133,7 +134,7 @@ def _interpolate_missing_values(X: np.ndarray): "imputation": imputed_data, } return result_dict - + def impute( self, test_set: Union[dict, str], @@ -157,4 +158,4 @@ def impute( """ result_dict = self.predict(test_set, file_type=file_type) - return result_dict["imputation"] \ No newline at end of file + return result_dict["imputation"] diff --git a/tests/imputation/lerp.py b/tests/imputation/lerp.py index d30909e8..41b396d6 100644 --- a/tests/imputation/lerp.py +++ b/tests/imputation/lerp.py @@ -71,4 +71,4 @@ def test_4_lazy_loading(self): if __name__ == "__main__": - unittest.main() \ No newline at end of file + unittest.main()