From da7e624b96e7f8f42aec475204067c91fe112db2 Mon Sep 17 00:00:00 2001 From: PavanSiligam Date: Sat, 20 Jul 2024 08:21:32 +0200 Subject: [PATCH] added support for xr.DataArray --- src/pymorize/units.py | 143 ++++++++------------------- tests/test_units.py | 221 +++++++++++++++++++++++------------------- 2 files changed, 159 insertions(+), 205 deletions(-) diff --git a/src/pymorize/units.py b/src/pymorize/units.py index de70bc1..f591376 100644 --- a/src/pymorize/units.py +++ b/src/pymorize/units.py @@ -12,126 +12,61 @@ import re from typing import Pattern -import pint +import cf_xarray.units +import pint_xarray import xarray as xr from chemicals import periodic_table from loguru import logger -ureg = pint.UnitRegistry() -ureg.define("degC = degree_Celsius") -# https://ncics.org/portfolio/other-resources/udunits2/ -ureg.define("degrees_east = deg") -ureg.define("degree_east = deg") -ureg.define("degrees_north = deg") -ureg.define("degree_north = deg") -ureg.define("degrees_west = -1 * deg") -ureg.define("degrees_south = -1 * deg") -# chemicals -# https://github.com/CalebBell/chemicals/ -#ureg.define(f"molN = {periodic_table.N.MW} * g") -#ureg.define(f"molC = {periodic_table.C.MW} * g") -#ureg.define(f"molFe = {periodic_table.Fe.MW} * g") +ureg = pint_xarray.unit_registry def handle_chemicals(s: str, pattern: Pattern = re.compile(r"mol(?P<symbol>\w+)")): """Registers known chemical elements definitions to global ureg (unit registry)""" - match = pattern.search(s) + try: + match = pattern.search(s) + except TypeError: + return if match: d = match.groupdict() try: - element = getattr(periodic_table, d['symbol']) + element = getattr(periodic_table, d["symbol"]) except AttributeError: - raise ValueError(f"Unknown chemical element {d.groupdict()['symbol']} in {d.group()}") + raise ValueError( + f"Unknown chemical element {d.groupdict()['symbol']} in {d.group()}" + ) else: - ureg.define(f"{match.group()} = {element.MW} * g") - - - -def _normalize_exponent_notation( - s: str, 
pattern: Pattern = re.compile(r"(?P<name>\w+)-(?P<exp>\d+)") -): - """Converts a string with exponents written as 'name-exp' into a more readable - exponent notation 'name^-exp'. - Example: 'm-2' gets converted as m^-2" - """ - - def correction(match): - try: - float(match.group()) - except ValueError: - d = match.groupdict() - s = f"{d['name']}^-{d['exp']}" - return s - return match.group() - - return re.sub(pattern, correction, s) - - -def _normalize_power_notation( - s: str, pattern: Pattern = re.compile(r"(?P<name>\w+)(?P<exp>\d+)") -): - """Converts a string with exponents written as 'nameexp' into a more readable - exponent notation 'name^exp'. - Example: 'm2' gets converted as m^2" - """ - - def correction(match): - try: - float(match.group()) - except ValueError: - d = match.groupdict() - s = f"{d['name']}^{d['exp']}" - if d["exp"] == 1: - s = f"{d['name']}" - return s - return match.group() - - return re.sub(pattern, correction, s) - + try: + ureg(s) + except pint_xarray.pint.errors.UndefinedUnitError: + logger.debug(f"Chemical element {element.name} detected in units {s}.") + logger.debug( + f"Registering definition: {match.group()} = {element.MW} * g" + ) + ureg.define(f"{match.group()} = {element.MW} * g") -def to_caret_notation(unit): - "Formats the unit so Pint can understand them" - return _normalize_power_notation(_normalize_exponent_notation(unit)) +def handle_unit_conversion(da: xr.DataArray, unit: str, source_unit: str = None): + """Performs the unit-aware data conversion. 
-def calculate_unit_conversion_factor(a: str, b: str) -> float: - """ - Returns the factor required to convert from unit "a" to unit "b" - """ - try: - A = ureg(a) - except (pint.errors.DimensionalityError, pint.errors.UndefinedUnitError): - handle_chemicals(a) - A = to_caret_notation(a) - A = ureg(A) - try: - B = ureg(b) - except (pint.errors.DimensionalityError, pint.errors.UndefinedUnitError): - handle_chemicals(b) - B = to_caret_notation(b) - B = ureg(B) - logger.debug(A) - logger.debug(B) - return A.to(B).magnitude - + If `source_unit` is provided, it is used instead of the unit from DataArray. -def handle_unit_conversion(da: xr.DataArray, to_units: str, from_units: str=None, **kwargs) -> xr.DataArray: - """Does the unit conversion by applying the conversion factor. Parameters: - ---------- - `from_units`: source units. If not provided will be read from DataArray. - `to_units`: target units + ----------- + da: xr.DataArray + unit: unit to convert data to + source_unit: Override the unit on xr.DataArray if needed. 
""" - if from_units is None: - from_units = getattr(da, 'units') - factor = calculate_unit_conversion_factor(from_units, to_units) - if factor != 1: - da = da * factor - # do we need to set `to_units` here on the data array - da.units = to_units - return da - - -def is_equal(a: str, b: str): - "check if both 'a' and 'b' are equal" - return ureg(to_caret_notation(a)) == ureg(to_caret_notation(b)) + from_unit = da.attrs.get("units") + if source_unit: + logger.debug( + f"using user defined unit ({source_unit}) instead of ({from_unit}) from DataArray " + ) + from_unit = source_unit + handle_chemicals(from_unit) + handle_chemicals(unit) + new_da = da.pint.quantify(from_unit) + new_da = new_da.pint.to(unit).pint.dequantify() + logger.debug(f"setting units on DataArray: {unit}") + new_da.attrs["units"] = unit + return new_da diff --git a/tests/test_units.py b/tests/test_units.py index 561284c..7fc4783 100644 --- a/tests/test_units.py +++ b/tests/test_units.py @@ -1,122 +1,111 @@ -import pint import pytest + +# import cf_xarray.units +# import pint_xarray +import xarray as xr +import pint +import numpy as np from chemicals import periodic_table +from pymorize.units import ( + handle_unit_conversion, + ureg, + handle_chemicals, +) -from pymorize.units import (calculate_unit_conversion_factor, - to_caret_notation, ureg) # input samples that are found in CMIP6 tables and in fesom1 (recom) allunits = [ - # input, expected - ("%", "%"), - ("(mol/kg) / atm", "(mol/kg) / atm"), - ("0.001", "0.001"), - ("1", "1"), - ("1.e6 J m-1 s-1", "1.e^6 J m^-1 s^-1"), - ("1e-06", "1e-06"), - ("1e-09", "1e-09"), - ("1e-12", "1e-12"), - ("1e-3 kg m-2", "1e-3 kg m^-2"), - ("1e-6 m s-1", "1e-6 m s^-1"), - ("1e3 km3", "1e3 km^3"), - ("1e6 km2", "1e6 km^2"), - ("J m-2", "J m^-2"), - ("K", "K"), - ("K Pa s-1", "K Pa s^-1"), - ("K m s-1", "K m s^-1"), - ("K s-1", "K s^-1"), - ("K2", "K^2"), - ("N m-1", "N m^-1"), - ("N m-2", "N m^-2"), - ("Pa", "Pa"), - ("Pa m s-2", "Pa m s^-2"), - ("Pa s-1", "Pa 
s^-1"), - ("Pa2 s-2", "Pa^2 s^-2"), - ("W", "W"), - ("W m-2", "W m^-2"), - ("W/m2", "W/m^2"), - ("day", "day"), - ("degC", "degC"), - ("degC kg m-2", "degC kg m^-2"), - ("degC2", "degC^2"), - ("degree", "degree"), - ("degrees_east", "degrees_east"), - ("degrees_north", "degrees_north"), - ("kg", "kg"), - ("kg kg-1", "kg kg^-1"), - ("kg m-1 s-1", "kg m^-1 s^-1"), - ("kg m-2", "kg m^-2"), - ("kg m-2 s-1", "kg m^-2 s^-1"), - ("kg m-3", "kg m^-3"), - ("kg s-1", "kg s^-1"), - ("km-2 s-1", "km^-2 s^-1"), - ("m", "m"), - ("m s-1", "m s^-1"), - ("m s-1 d-1", "m s^-1 d^-1"), - ("m s-2", "m s^-2"), - ("m-1", "m^-1"), - ("m-1 sr-1", "m^-1 sr^-1"), - ("m-2", "m^-2"), - ("m-3", "m^-3"), - ("m2", "m^2"), - ("m2 s-1", "m^2 s^-1"), - ("m2 s-2", "m^2 s^-2"), - ("m3", "m^3"), - ("m3 s-1", "m^3 s^-1"), - ("m3 s-2", "m^3 s^-2"), - ("m4 s-1", "m^4 s^-1"), - ("mmol/m2", "mmol/m^2"), - ("mmol/m2/d", "mmol/m^2/d"), - ("mmolC/(m2*d)", "mmolC/(m^2*d)"), - ("mmolC/(m3*d)", "mmolC/(m^3*d)"), - ("mmolC/d", "mmolC/d"), - ("mmolC/m2/d", "mmolC/m^2/d"), - ("mmolN/(m2*d)", "mmolN/(m^2*d)"), - ("mmolN/d", "mmolN/d"), - ("mmolN/m2/s", "mmolN/m^2/s"), - ("mol m-2", "mol m^-2"), - ("mol m-2 s-1", "mol m^-2 s^-1"), - ("mol m-3", "mol m^-3"), - ("mol m-3 s-1", "mol m^-3 s^-1"), - ("mol mol-1", "mol mol^-1"), - ("mol s-1", "mol s^-1"), - ("mol/kg", "mol/kg"), - ("s", "s"), - ("s m-1", "s m^-1"), - ("s-1", "s^-1"), - ("s-2", "s^-2"), - ("uatm", "uatm"), - ("umolFe/m2/s", "umolFe/m^2/s"), - ("year", "year"), - ("yr", "yr"), + "%", + "0.001", + "1", + "1.e6 J m-1 s-1", + "1e-06", + "1e-3 kg m-2", + "1e3 km3", + "J m-2", + "K", + "K Pa s-1", + "K s-1", + "K2", + "Pa2 s-2", + "W m^-2", + "W/m2", + "W/m^2", + "day", + "degC", + "degC kg m-2", + "degC2", + "degree", + "degrees_east", + "degrees_north", + "kg kg-1", + "kg m-2 s-1", + "kg m-3", + "kg s-1", + "km-2 s-1", + "m-1 sr-1", + "m-2", + "m^-3", + "m^2", + "mol/kg", + "mol/m2", + "mol m-2", + "mol m^-2", + "(mol/kg) / atm", + "mmol/m2/d", + "uatm", + 
"year", + "yr", ] -@pytest.mark.parametrize("test_input,expected", allunits) -def test_can_convert_SI_notation_to_caret_notation(test_input, expected): - u = to_caret_notation(test_input) - assert u == expected +@pytest.mark.parametrize("test_input", allunits) +def test_can_read_units(test_input): + ureg(test_input) -mixed_notation_to_slash = [("mmolC/m2/d", "mmolC/m^2/d")] +units_with_chemical_element = [ + "mmolC/(m2*d)", + "mmolC/d", + "mmolC/m2/d", + "mmolN/(m2*d)", + "mmolN/d", + "umolFe/m2/s", +] -@pytest.mark.parametrize("test_input,expected", mixed_notation_to_slash) -def test_can_convert_mixed_notation_to_caret_notation(test_input, expected): - u = to_caret_notation(test_input) - assert u == expected +@pytest.mark.parametrize("test_input", units_with_chemical_element) +def test_handle_chemicals(test_input): + handle_chemicals(test_input) + ureg(test_input) -def test_can_convert_to_different_units(): +def test_can_handle_simple_chemical_elements(): + from_unit = "molC" + to_unit = "g" + da = xr.DataArray(10, attrs={"units": from_unit}) + new_da = handle_unit_conversion(da, to_unit) + assert new_da.data == np.array(periodic_table.Carbon.MW * 10) + assert new_da.attrs["units"] == to_unit + + +def test_can_handle_chemical_elements(): from_unit = "mmolC/m2/d" to_unit = "kg m-2 s-1" - factor = calculate_unit_conversion_factor(from_unit, to_unit) - assert factor == 1.3901273148148146e-10 + da = xr.DataArray(10, attrs={"units": from_unit}) + new_da = handle_unit_conversion(da, to_unit) + assert np.allclose(new_da.data, np.array(1.39012731e-09)) + assert new_da.attrs["units"] == to_unit -def test_non_caret_notation_raises_error(): - with pytest.raises(pint.errors.DimensionalityError): - ureg("kg m-2 s-1") +def test_user_defined_units_takes_precedence_over_units_in_dataarray(): + from_unit = "molC" + to_unit = "g" + da = xr.DataArray(10, attrs={"units": "kg"}) + # here, "molC" will be used instead of "kg" + new_da = handle_unit_conversion(da, to_unit, from_unit) + 
assert new_da.data == np.array(periodic_table.Carbon.MW * 10) + assert new_da.attrs["units"] == to_unit def test_without_defining_uraninum_to_weight_conversion_raises_error(): @@ -125,5 +114,35 @@ def test_without_defining_uraninum_to_weight_conversion_raises_error(): ureg("mmolU/m**2/d") -def test_define_carbon_to_weight_conversion(): +def test_recognizes_previous_defined_chemical_elements(): assert "mmolC/m^2/d" in ureg + + +def test_works_when_both_units_are_None(): + to_unit = None + da = xr.DataArray(10, attrs={"units": None}) + new_da = handle_unit_conversion(da, to_unit) + assert new_da.attrs["units"] == to_unit + + +def test_works_when_both_units_are_empty_string(): + to_unit = "" + da = xr.DataArray(10, attrs={"units": ""}) + new_da = handle_unit_conversion(da, to_unit) + assert new_da.attrs["units"] == to_unit + + +@pytest.mark.parametrize("from_unit", ["m/s", None, ""]) +def test_when_target_units_is_None_overrides_existing_units(from_unit): + to_unit = None + da = xr.DataArray(10, attrs={"units": from_unit}) + new_da = handle_unit_conversion(da, to_unit) + assert new_da.attrs["units"] == to_unit + + +@pytest.mark.parametrize("from_unit", ["m/s", None]) +def test_when_tartget_unit_is_empty_string_raises_error(from_unit): + to_unit = "" + da = xr.DataArray(10, attrs={"units": from_unit}) + with pytest.raises(ValueError): + handle_unit_conversion(da, to_unit)