Skip to content

Commit

Permalink
cleanups in fit_curve
Browse files Browse the repository at this point in the history
  • Loading branch information
Lukas Weidenholzer committed Jul 24, 2023
1 parent f716b87 commit a55decd
Showing 1 changed file with 11 additions and 3 deletions.
14 changes: 11 additions & 3 deletions openeo_processes_dask/process_implementations/ml/curve_fitting.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Callable

from openeo_processes_dask.process_implementations.data_model import RasterCube
from openeo_processes_dask.process_implementations.exceptions import (
DimensionNotAvailable,
Expand All @@ -6,15 +8,21 @@
__all__ = ["fit_curve", "predict_curve"]


def fit_curve(data: RasterCube, parameters: list, function, dimension):
parameters = {f"param_{i}": v for i, v in enumerate(parameters)}

def fit_curve(data: RasterCube, parameters: list, function: Callable, dimension: str):
if dimension not in data.dims:
raise DimensionNotAvailable(

Check warning on line 13 in openeo_processes_dask/process_implementations/ml/curve_fitting.py

View check run for this annotation

Codecov / codecov/patch

openeo_processes_dask/process_implementations/ml/curve_fitting.py#L13

Added line #L13 was not covered by tests
f"Provided dimension ({dimension}) not found in data.dims: {data.dims}"
)

# In the spec, parameters is a list, but xr.curvefit requires names for them,
# so we do this to generate names locally
parameters = {f"param_{i}": v for i, v in enumerate(parameters)}

# The dimension along which to predict cannot be chunked!
rechunked_data = data.chunk({dimension: -1})

# .curvefit returns some extra information that isn't required by the OpenEO process
# so we simply drop these here.
fit_result = rechunked_data.curvefit(
dimension, function, p0=parameters, param_names=list(parameters.keys())
).drop_dims(["cov_i", "cov_j"])
Expand Down

0 comments on commit a55decd

Please sign in to comment.