Skip to content

Commit

Permalink
added support for xr.DataArray
Browse files Browse the repository at this point in the history
  • Loading branch information
siligam committed Jul 20, 2024
1 parent aaf0f9b commit da7e624
Show file tree
Hide file tree
Showing 2 changed files with 159 additions and 205 deletions.
143 changes: 39 additions & 104 deletions src/pymorize/units.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,126 +12,61 @@
import re
from typing import Pattern

import pint
import cf_xarray.units
import pint_xarray
import xarray as xr
from chemicals import periodic_table
from loguru import logger

ureg = pint.UnitRegistry()
ureg.define("degC = degree_Celsius")
# https://ncics.org/portfolio/other-resources/udunits2/
ureg.define("degrees_east = deg")
ureg.define("degree_east = deg")
ureg.define("degrees_north = deg")
ureg.define("degree_north = deg")
ureg.define("degrees_west = -1 * deg")
ureg.define("degrees_south = -1 * deg")
# chemicals
# https://github.com/CalebBell/chemicals/
#ureg.define(f"molN = {periodic_table.N.MW} * g")
#ureg.define(f"molC = {periodic_table.C.MW} * g")
#ureg.define(f"molFe = {periodic_table.Fe.MW} * g")
ureg = pint_xarray.unit_registry


def handle_chemicals(s: str, pattern: Pattern = re.compile(r"mol(?P<symbol>\w+)")):
"""Registers known chemical elements definitions to global ureg (unit registry)"""
match = pattern.search(s)
try:
match = pattern.search(s)
except TypeError:
return
if match:
d = match.groupdict()
try:
element = getattr(periodic_table, d['symbol'])
element = getattr(periodic_table, d["symbol"])
except AttributeError:
raise ValueError(f"Unknown chemical element {d.groupdict()['symbol']} in {d.group()}")
raise ValueError(
f"Unknown chemical element {d.groupdict()['symbol']} in {d.group()}"
)
else:
ureg.define(f"{match.group()} = {element.MW} * g")



def _normalize_exponent_notation(
s: str, pattern: Pattern = re.compile(r"(?P<name>\w+)-(?P<exp>\d+)")
):
"""Converts a string with exponents written as 'name-exp' into a more readable
exponent notation 'name^-exp'.
Example: 'm-2' gets converted as m^-2"
"""

def correction(match):
try:
float(match.group())
except ValueError:
d = match.groupdict()
s = f"{d['name']}^-{d['exp']}"
return s
return match.group()

return re.sub(pattern, correction, s)


def _normalize_power_notation(
s: str, pattern: Pattern = re.compile(r"(?P<name>\w+)(?P<exp>\d+)")
):
"""Converts a string with exponents written as 'nameexp' into a more readable
exponent notation 'name^exp'.
Example: 'm2' gets converted as m^2"
"""

def correction(match):
try:
float(match.group())
except ValueError:
d = match.groupdict()
s = f"{d['name']}^{d['exp']}"
if d["exp"] == 1:
s = f"{d['name']}"
return s
return match.group()

return re.sub(pattern, correction, s)

try:
ureg(s)
except pint_xarray.pint.errors.UndefinedUnitError:
logger.debug(f"Chemical element {element.name} detected in units {s}.")
logger.debug(
f"Registering definition: {match.group()} = {element.MW} * g"
)
ureg.define(f"{match.group()} = {element.MW} * g")

def to_caret_notation(unit):
"Formats the unit so Pint can understand them"
return _normalize_power_notation(_normalize_exponent_notation(unit))

def handle_unit_conversion(da: xr.DataArray, unit: str, source_unit: str = None):
"""Performs the unit-aware data conversion.
def calculate_unit_conversion_factor(a: str, b: str) -> float:
"""
Returns the factor required to convert from unit "a" to unit "b"
"""
try:
A = ureg(a)
except (pint.errors.DimensionalityError, pint.errors.UndefinedUnitError):
handle_chemicals(a)
A = to_caret_notation(a)
A = ureg(A)
try:
B = ureg(b)
except (pint.errors.DimensionalityError, pint.errors.UndefinedUnitError):
handle_chemicals(b)
B = to_caret_notation(b)
B = ureg(B)
logger.debug(A)
logger.debug(B)
return A.to(B).magnitude

If `source_unit` is provided, it is used instead of the unit from DataArray.
def handle_unit_conversion(da: xr.DataArray, to_units: str, from_units: str=None, **kwargs) -> xr.DataArray:
"""Does the unit conversion by applying the conversion factor.
Parameters:
----------
`from_units`: source units. If not provided will be read from DataArray.
`to_units`: target units
-----------
da: xr.DataArray
unit: unit to convert data to
source_unit: Override the unit on xr.DataArray if needed.
"""
if from_units is None:
from_units = getattr(da, 'units')
factor = calculate_unit_conversion_factor(from_units, to_units)
if factor != 1:
da = da * factor
# do we need to set `to_units` here on the data array
da.units = to_units
return da


def is_equal(a: str, b: str):
"check if both 'a' and 'b' are equal"
return ureg(to_caret_notation(a)) == ureg(to_caret_notation(b))
from_unit = da.attrs.get("units")
if source_unit:
logger.debug(
f"using user defined unit ({source_unit}) instead of ({from_unit}) from DataArray "
)
from_unit = source_unit
handle_chemicals(from_unit)
handle_chemicals(unit)
new_da = da.pint.quantify(from_unit)
new_da = new_da.pint.to(unit).pint.dequantify()
logger.debug(f"setting units on DataArray: {unit}")
new_da.attrs["units"] = unit
return new_da
Loading

0 comments on commit da7e624

Please sign in to comment.