From a07fe99caf3fb9a0579dd96e0f7822c9cb9ebe71 Mon Sep 17 00:00:00 2001 From: PavanSiligam Date: Fri, 20 Sep 2024 14:49:42 +0200 Subject: [PATCH 1/2] fixed the function signature for units handling. --- src/pymorize/units.py | 18 ++++++------------ tests/unit/test_units.py | 18 +++++++++--------- 2 files changed, 15 insertions(+), 21 deletions(-) diff --git a/src/pymorize/units.py b/src/pymorize/units.py index 83cac08..0109a39 100644 --- a/src/pymorize/units.py +++ b/src/pymorize/units.py @@ -86,12 +86,7 @@ def handle_chemicals( ureg.define(f"{match.group()} = {element.MW} * g") -# FIXME: This needs to have a different signature! -def handle_unit_conversion( - da: xr.DataArray, - rule_spec: Rule, - source_unit: Union[str, None] = None, -) -> xr.DataArray: +def handle_unit_conversion(da: xr.DataArray, rule: Rule) -> xr.DataArray: """Performs the unit-aware data conversion. If `source_unit` is provided, it is used instead of the unit from DataArray. @@ -101,8 +96,6 @@ def handle_unit_conversion( da: ~xr.DataArray unit: str unit to convert data to - source_unit: str or None - Override the unit on ``da.attrs.unit`` if needed. Returns ------- @@ -112,14 +105,15 @@ def handle_unit_conversion( if not isinstance(da, xr.DataArray): raise TypeError(f"Expected xr.DataArray, got {type(da)}") # data_request_variable needs to be defined at this point - drv = rule_spec.data_request_variable + drv = rule.data_request_variable to_unit = drv.unit + model_unit = rule.get("model_unit") from_unit = da.attrs.get("units") - if source_unit is not None: + if model_unit is not None: logger.debug( - f"using user defined unit ({source_unit}) instead of ({from_unit}) from DataArray " + f"using user defined unit ({model_unit}) instead of ({from_unit}) from DataArray " ) - from_unit = source_unit + from_unit = model_unit handle_chemicals(from_unit) handle_chemicals(to_unit) new_da = da.pint.quantify(from_unit) diff --git a/tests/unit/test_units.py b/tests/unit/test_units.py index 97b4a08..108960f 100644 --- a/tests/unit/test_units.py +++ b/tests/unit/test_units.py @@ -99,11 +99,11 @@ def test_can_handle_chemical_elements(rule_with_units): def test_user_defined_units_takes_precedence_over_units_in_dataarray(rule_with_units): rule_spec = rule_with_units rule_spec.data_request_variable.unit = "g" - from_unit = "molC" + rule_spec.model_unit = "molC" to_unit = "g" da = xr.DataArray(10, attrs={"units": "kg"}) # here, "molC" will be used instead of "kg" - new_da = handle_unit_conversion(da, rule_spec, from_unit) + new_da = handle_unit_conversion(da, rule_spec) assert new_da.data == np.array(periodic_table.Carbon.MW * 10) assert new_da.attrs["units"] == to_unit @@ -121,10 +121,10 @@ def test_recognizes_previous_defined_chemical_elements(): def test_works_when_both_units_are_None(rule_with_units): rule_spec = rule_with_units rule_spec.data_request_variable.unit = None - to_unit = None + rule_spec.model_unit = None da = xr.DataArray(10, attrs={"units": None}) - new_da = handle_unit_conversion(da, rule_spec, to_unit) - assert new_da.attrs["units"] == to_unit + new_da = handle_unit_conversion(da, rule_spec) + assert new_da.attrs["units"] == None def test_works_when_both_units_are_empty_string(rule_with_units): @@ -143,16 +143,16 @@ def test_when_target_units_is_None_overrides_existing_units(rule_with_units, fro drv = rule_spec.data_request_variable if hasattr(drv, "unit"): drv.unit = from_unit - to_unit = None + rule_spec.model_unit = None da = xr.DataArray(10, attrs={"units": from_unit}) - new_da = handle_unit_conversion(da, rule_spec, to_unit) + new_da = handle_unit_conversion(da, rule_spec) assert new_da.attrs["units"] == to_unit @pytest.mark.parametrize("from_unit", ["m/s", None]) def test_when_tartget_unit_is_empty_string_raises_error(rule_with_units, from_unit): rule_spec = rule_with_units - to_unit = "" + rule_spec.model_unit = "" da = xr.DataArray(10, attrs={"units": from_unit}) with pytest.raises(ValueError): - handle_unit_conversion(da, rule_spec, to_unit) + handle_unit_conversion(da, rule_spec) From 291cd84d8bc78c157c69ea86f81e69fc6919d3ca Mon Sep 17 00:00:00 2001 From: PavanSiligam Date: Fri, 20 Sep 2024 14:58:46 +0200 Subject: [PATCH 2/2] fixed typo --- tests/unit/test_units.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/test_units.py b/tests/unit/test_units.py index 108960f..1c00594 100644 --- a/tests/unit/test_units.py +++ b/tests/unit/test_units.py @@ -146,7 +146,7 @@ def test_when_target_units_is_None_overrides_existing_units(rule_with_units, fro rule_spec.model_unit = None da = xr.DataArray(10, attrs={"units": from_unit}) new_da = handle_unit_conversion(da, rule_spec) - assert new_da.attrs["units"] == to_unit + assert new_da.attrs["units"] == drv.unit @pytest.mark.parametrize("from_unit", ["m/s", None])