Skip to content

Commit

Permalink
chore: remove deprecated .drop() function in drop dimension (#136)
Browse files Browse the repository at this point in the history
* remove deprecated .drop() in drop_dimension

* add tests for drop_dimension
  • Loading branch information
LukeWeidenwalker authored Jul 14, 2023
1 parent afca7c6 commit 6892cf0
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def drop_dimension(data: RasterCube, name: str) -> RasterCube:
raise DimensionLabelCountMismatch(
f"The number of dimension labels exceeds one, which requires a reducer. Dimension ({name}) has {len(data[name])} labels."
)
return data.drop(name)
return data.drop_vars(name).squeeze()


def create_raster_cube() -> RasterCube:
Expand Down
33 changes: 32 additions & 1 deletion tests/test_dimensions.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,14 @@
import numpy as np
import pytest

from openeo_processes_dask.process_implementations.cubes.general import add_dimension
from openeo_processes_dask.process_implementations.cubes.general import (
add_dimension,
drop_dimension,
)
from openeo_processes_dask.process_implementations.exceptions import (
DimensionLabelCountMismatch,
DimensionNotAvailable,
)
from tests.general_checks import general_output_checks
from tests.mockdata import create_fake_rastercube

Expand Down Expand Up @@ -33,3 +40,27 @@ def test_add_dimension(temporal_interval, bounding_box, random_raster_data):
data=input_cube, name="weird", label="test", type="temporal"
)
assert output_cube_2.openeo.temporal_dims[1] == "weird"


@pytest.mark.parametrize("size", [(30, 30, 20, 2)])
@pytest.mark.parametrize("dtype", [np.float32])
def test_drop_dimension(temporal_interval, bounding_box, random_raster_data):
input_cube = create_fake_rastercube(
data=random_raster_data,
spatial_extent=bounding_box,
temporal_extent=temporal_interval,
bands=["B02", "B04"],
backend="dask",
)
DIM_TO_DROP = "bands"

with pytest.raises(DimensionNotAvailable):
drop_dimension(input_cube, "notthere")

with pytest.raises(DimensionLabelCountMismatch):
drop_dimension(input_cube, DIM_TO_DROP)

suitable_cube = input_cube.where(input_cube.bands == "B02", drop=True)

output_cube = drop_dimension(suitable_cube, DIM_TO_DROP)
assert DIM_TO_DROP not in output_cube.dims

0 comments on commit 6892cf0

Please sign in to comment.