Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add weight threshold option for temporal operations #683

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
414 changes: 388 additions & 26 deletions tests/test_temporal.py
Original file line number Diff line number Diff line change
@@ -489,14 +489,6 @@ def test_weighted_annual_averages(self):
cftime.DatetimeGregorian(2001, 1, 1),
],
),
coords={
"time": np.array(
[
cftime.DatetimeGregorian(2000, 1, 1),
cftime.DatetimeGregorian(2001, 1, 1),
],
)
},
dims=["time"],
attrs={
"axis": "T",
@@ -540,14 +532,6 @@ def test_weighted_annual_averages_with_chunking(self):
cftime.DatetimeGregorian(2001, 1, 1),
],
),
coords={
"time": np.array(
[
cftime.DatetimeGregorian(2000, 1, 1),
cftime.DatetimeGregorian(2001, 1, 1),
],
)
},
dims=["time"],
attrs={
"axis": "T",
@@ -571,6 +555,195 @@ def test_weighted_annual_averages_with_chunking(self):
assert result.ts.attrs == expected.ts.attrs
assert result.time.attrs == expected.time.attrs

def test_weighted_annual_averages_with_masked_data_and_min_weight_threshold_of_100_percent(
    self,
):
    """Annual averages are masked when a year has any missing data.

    With ``min_weight=1.0`` (100% coverage required), a year containing a
    missing (``np.nan``) value cannot reach full weight, so its annual
    average is replaced with ``np.nan``.
    """
    # Set up dataset
    ds = xr.Dataset(
        coords={
            "lat": [-90],
            "lon": [0],
            "time": xr.DataArray(
                data=np.array(
                    [
                        "2000-01-01T00:00:00.000000000",
                        "2000-02-01T00:00:00.000000000",
                        "2001-01-01T00:00:00.000000000",
                        "2001-02-01T00:00:00.000000000",
                        "2002-01-01T00:00:00.000000000",
                    ],
                    dtype="datetime64[ns]",
                ),
                dims=["time"],
                attrs={
                    "axis": "T",
                    "long_name": "time",
                    "standard_name": "time",
                    "bounds": "time_bnds",
                },
            ),
        }
    )
    ds.time.encoding = {"calendar": "standard"}

    # NOTE(review): the bounds for the two 2001 time steps reference year
    # 2000 (e.g. ["2001-01-01", "2000-01-01"]), i.e. the upper bound is
    # earlier than the lower bound — confirm this is intentional, since
    # these bounds determine the temporal weights.
    ds["time_bnds"] = xr.DataArray(
        name="time_bnds",
        data=np.array(
            [
                ["2000-01-01T00:00:00.000000000", "2000-02-01T00:00:00.000000000"],
                ["2000-02-01T00:00:00.000000000", "2000-03-01T00:00:00.000000000"],
                ["2001-01-01T00:00:00.000000000", "2000-01-01T00:00:00.000000000"],
                ["2001-02-01T00:00:00.000000000", "2000-03-01T00:00:00.000000000"],
                ["2002-01-01T00:00:00.000000000", "2002-02-01T00:00:00.000000000"],
            ],
            dtype="datetime64[ns]",
        ),
        coords={"time": ds.time},
        dims=["time", "bnds"],
        attrs={"xcdat_bounds": "True"},
    )

    # Year 2000 contains one missing value (the Feb time step).
    ds["ts"] = xr.DataArray(
        data=np.array([[[2]], [[np.nan]], [[1]], [[1]], [[0.5]]]),
        coords={"lat": ds.lat, "lon": ds.lon, "time": ds.time},
        dims=["time", "lat", "lon"],
    )

    # NOTE: If a cell has a missing value for any of the years, the average
    # for that year should be masked with a min_weight threshold of 100%.
    result = ds.temporal.group_average("ts", "year", min_weight=1.0)
    expected = ds.copy()
    expected = expected.drop_dims("time")
    # Year 2000 is masked (np.nan); 2001 and 2002 have full coverage.
    expected["ts"] = xr.DataArray(
        name="ts",
        data=np.array([[[np.nan]], [[1]], [[0.5]]]),
        coords={
            "lat": expected.lat,
            "lon": expected.lon,
            "time": xr.DataArray(
                data=np.array(
                    [
                        cftime.DatetimeGregorian(2000, 1, 1),
                        cftime.DatetimeGregorian(2001, 1, 1),
                        cftime.DatetimeGregorian(2002, 1, 1),
                    ],
                ),
                dims=["time"],
                attrs={
                    "axis": "T",
                    "long_name": "time",
                    "standard_name": "time",
                    "bounds": "time_bnds",
                },
            ),
        },
        dims=["time", "lat", "lon"],
        # NOTE(review): "test_attr" is never set on ``ds.ts`` above, so the
        # result cannot carry it; this only passes because
        # ``assert_allclose`` ignores attributes — confirm intended.
        attrs={
            "test_attr": "test",
            "operation": "temporal_avg",
            "mode": "group_average",
            "freq": "year",
            "weighted": "True",
        },
    )

    # ``assert_allclose`` compares values and coordinates, not attrs.
    xr.testing.assert_allclose(result, expected)

def test_weighted_annual_averages_with_masked_data_and_min_weight_threshold_of_50_percent(
    self,
):
    """Annual averages survive a 50% threshold when valid weight >= 0.5.

    Year 2000 has a missing Feb value, but Jan alone carries more than
    half the year's weight (31 vs. 29 days of the two bounded months), so
    the 2000 average is kept rather than masked.
    """
    # Set up dataset
    ds = xr.Dataset(
        coords={
            "lat": [-90],
            "lon": [0],
            "time": xr.DataArray(
                data=np.array(
                    [
                        "2000-01-01T00:00:00.000000000",
                        "2000-02-01T00:00:00.000000000",
                        "2001-01-01T00:00:00.000000000",
                        "2001-02-01T00:00:00.000000000",
                        "2002-01-01T00:00:00.000000000",
                    ],
                    dtype="datetime64[ns]",
                ),
                dims=["time"],
                attrs={
                    "axis": "T",
                    "long_name": "time",
                    "standard_name": "time",
                    "bounds": "time_bnds",
                },
            ),
        }
    )
    ds.time.encoding = {"calendar": "standard"}

    # NOTE(review): the bounds for the two 2001 time steps reference year
    # 2000 (upper bound earlier than lower bound) — confirm this is
    # intentional, since these bounds determine the temporal weights.
    ds["time_bnds"] = xr.DataArray(
        name="time_bnds",
        data=np.array(
            [
                ["2000-01-01T00:00:00.000000000", "2000-02-01T00:00:00.000000000"],
                ["2000-02-01T00:00:00.000000000", "2000-03-01T00:00:00.000000000"],
                ["2001-01-01T00:00:00.000000000", "2000-01-01T00:00:00.000000000"],
                ["2001-02-01T00:00:00.000000000", "2000-03-01T00:00:00.000000000"],
                ["2002-01-01T00:00:00.000000000", "2002-02-01T00:00:00.000000000"],
            ],
            dtype="datetime64[ns]",
        ),
        coords={"time": ds.time},
        dims=["time", "bnds"],
        attrs={"xcdat_bounds": "True"},
    )

    # Year 2000 contains one missing value (the Feb time step).
    ds["ts"] = xr.DataArray(
        data=np.array([[[2]], [[np.nan]], [[1]], [[1]], [[0.5]]]),
        coords={"lat": ds.lat, "lon": ds.lon, "time": ds.time},
        dims=["time", "lat", "lon"],
    )

    # NOTE: The second cell of "ts" has missing data, but the first cell
    # has more weight (due to more days in the month of Jan vs. Feb) so the
    # average for the year is not masked with a min_weight threshold of 50%.
    result = ds.temporal.group_average("ts", "year", min_weight=0.50)
    expected = ds.copy()
    expected = expected.drop_dims("time")
    # Year 2000 keeps the Jan-only weighted value (2.0) since NaN data
    # receives zero weight; 2001 and 2002 have full coverage.
    expected["ts"] = xr.DataArray(
        name="ts",
        data=np.array([[[2.0]], [[1]], [[0.5]]]),
        coords={
            "lat": expected.lat,
            "lon": expected.lon,
            "time": xr.DataArray(
                data=np.array(
                    [
                        cftime.DatetimeGregorian(2000, 1, 1),
                        cftime.DatetimeGregorian(2001, 1, 1),
                        cftime.DatetimeGregorian(2002, 1, 1),
                    ],
                ),
                dims=["time"],
                attrs={
                    "axis": "T",
                    "long_name": "time",
                    "standard_name": "time",
                    "bounds": "time_bnds",
                },
            ),
        },
        dims=["time", "lat", "lon"],
        # NOTE(review): "test_attr" is never set on ``ds.ts`` above; this
        # only passes because ``assert_allclose`` ignores attributes.
        attrs={
            "test_attr": "test",
            "operation": "temporal_avg",
            "mode": "group_average",
            "freq": "year",
            "weighted": "True",
        },
    )

    # ``assert_allclose`` compares values and coordinates, not attrs.
    xr.testing.assert_allclose(result, expected)

def test_weighted_seasonal_averages_with_DJF_and_drop_incomplete_seasons(self):
ds = self.ds.copy()

@@ -625,6 +798,14 @@ def test_weighted_seasonal_averages_with_DJF_without_dropping_incomplete_seasons
self,
):
ds = self.ds.copy()
ds["ts"] = xr.DataArray(
data=np.array(
[[[2.0]], [[1.0]], [[1.0]], [[1.0]], [[2.0]]], dtype="float64"
),
coords={"time": self.ds.time, "lat": self.ds.lat, "lon": self.ds.lon},
dims=["time", "lat", "lon"],
attrs={"test_attr": "test"},
)

result = ds.temporal.group_average(
"ts",
@@ -698,17 +879,102 @@ def test_weighted_seasonal_averages_with_JFD(self):
cftime.DatetimeGregorian(2001, 1, 1),
],
),
coords={
"time": np.array(
[
cftime.DatetimeGregorian(2000, 1, 1),
cftime.DatetimeGregorian(2000, 4, 1),
cftime.DatetimeGregorian(2000, 7, 1),
cftime.DatetimeGregorian(2000, 10, 1),
cftime.DatetimeGregorian(2001, 1, 1),
],
)
dims=["time"],
attrs={
"axis": "T",
"long_name": "time",
"standard_name": "time",
"bounds": "time_bnds",
},
),
},
dims=["time", "lat", "lon"],
attrs={
"test_attr": "test",
"operation": "temporal_avg",
"mode": "group_average",
"freq": "season",
"weighted": "True",
"dec_mode": "JFD",
},
)

xr.testing.assert_identical(result, expected)

def test_weighted_seasonal_averages_with_JFD_with_min_weight_threshold_of_100_percent(
self,
):
time = xr.DataArray(
data=np.array(
[
"2000-01-16T12:00:00.000000000",
"2000-02-15T12:00:00.000000000",
"2000-03-16T12:00:00.000000000",
"2000-06-16T00:00:00.000000000",
"2000-12-16T00:00:00.000000000",
],
dtype="datetime64[ns]",
),
dims=["time"],
attrs={"axis": "T", "long_name": "time", "standard_name": "time"},
)
time.encoding = {"calendar": "standard"}
time_bnds = xr.DataArray(
name="time_bnds",
data=np.array(
[
["2000-01-01T00:00:00.000000000", "2000-02-01T00:00:00.000000000"],
["2000-02-01T00:00:00.000000000", "2000-03-01T00:00:00.000000000"],
["2000-03-01T00:00:00.000000000", "2000-04-01T00:00:00.000000000"],
["2000-06-01T00:00:00.000000000", "2000-07-01T00:00:00.000000000"],
["2000-11-01T00:00:00.000000000", "2001-01-01T00:00:00.000000000"],
],
dtype="datetime64[ns]",
),
coords={"time": time},
dims=["time", "bnds"],
attrs={"xcdat_bounds": "True"},
)

ds = xr.Dataset(
data_vars={"time_bnds": time_bnds},
coords={"lat": [-90], "lon": [0], "time": time},
)
ds.time.attrs["bounds"] = "time_bnds"

ds["ts"] = xr.DataArray(
data=np.array(
[[[np.nan]], [[1.0]], [[1.0]], [[1.0]], [[2.0]]], dtype="float64"
),
coords={"time": self.ds.time, "lat": self.ds.lat, "lon": self.ds.lon},
dims=["time", "lat", "lon"],
attrs={"test_attr": "test"},
)

# NOTE: If a cell has a missing value for any of the seasons, the average
# for that season should be masked with a min_weight threshold of 100%.
result = ds.temporal.group_average(
"ts",
"season",
season_config={"dec_mode": "JFD"},
min_weight=1.0,
)
expected = ds.copy()
expected = expected.drop_dims("time")
expected["ts"] = xr.DataArray(
name="ts",
data=np.array([[[np.nan]], [[1.0]], [[1.0]]]),
coords={
"lat": expected.lat,
"lon": expected.lon,
"time": xr.DataArray(
data=np.array(
[
cftime.DatetimeGregorian(2000, 1, 1),
cftime.DatetimeGregorian(2000, 4, 1),
cftime.DatetimeGregorian(2000, 7, 1),
],
),
dims=["time"],
attrs={
"axis": "T",
@@ -926,6 +1192,102 @@ def test_weighted_monthly_averages_with_masked_data(self):

xr.testing.assert_identical(result, expected)

def test_weighted_monthly_averages_with_masked_data_and_min_weight_threshold_of_100_percent(
    self,
):
    """Monthly averages are masked when valid weight falls below threshold.

    Feb 2000 has two samples (each covering the full month, so equal
    weight); one is missing, leaving only 50% valid weight — below the
    0.55 threshold — so the Feb 2000 average is masked.

    NOTE(review): the test name says "100_percent" but the call below uses
    ``min_weight=0.55`` — confirm which was intended and rename/adjust.
    """
    # Set up dataset
    ds = xr.Dataset(
        coords={
            "lat": [-90],
            "lon": [0],
            "time": xr.DataArray(
                data=np.array(
                    [
                        "2000-01-01T00:00:00.000000000",
                        "2000-02-01T00:00:00.000000000",
                        "2000-02-15T00:00:00.000000000",
                        "2000-04-01T00:00:00.000000000",
                        "2001-02-01T00:00:00.000000000",
                    ],
                    dtype="datetime64[ns]",
                ),
                dims=["time"],
                attrs={
                    "axis": "T",
                    "long_name": "time",
                    "standard_name": "time",
                    "bounds": "time_bnds",
                },
            ),
        }
    )
    ds.time.encoding = {"calendar": "standard"}

    # Rows 2 and 3 deliberately share identical Feb 2000 bounds so the two
    # Feb samples split that month's weight evenly.
    ds["time_bnds"] = xr.DataArray(
        name="time_bnds",
        data=np.array(
            [
                ["2000-01-01T00:00:00.000000000", "2000-02-01T00:00:00.000000000"],
                ["2000-02-01T00:00:00.000000000", "2000-03-01T00:00:00.000000000"],
                ["2000-02-01T00:00:00.000000000", "2000-03-01T00:00:00.000000000"],
                ["2000-04-01T00:00:00.000000000", "2000-05-01T00:00:00.000000000"],
                ["2001-02-01T00:00:00.000000000", "2001-03-01T00:00:00.000000000"],
            ],
            dtype="datetime64[ns]",
        ),
        coords={"time": ds.time},
        dims=["time", "bnds"],
        attrs={"xcdat_bounds": "True"},
    )

    # One of the two Feb 2000 samples is missing.
    ds["ts"] = xr.DataArray(
        data=np.array([[[2]], [[np.nan]], [[1]], [[1]], [[1]]]),
        coords={"lat": ds.lat, "lon": ds.lon, "time": ds.time},
        dims=["time", "lat", "lon"],
        attrs={"test_attr": "test"},
    )

    # NOTE: Feb 2000 only retains 50% of its weight after masking the
    # missing sample, which is below the 0.55 min_weight threshold, so its
    # monthly average should be masked.
    result = ds.temporal.group_average("ts", "month", min_weight=0.55)
    expected = ds.copy()
    expected = expected.drop_dims("time")
    # Feb 2000 is masked (np.nan); the other months have full coverage.
    expected["ts"] = xr.DataArray(
        name="ts",
        data=np.array([[[2.0]], [[np.nan]], [[1.0]], [[1.0]]]),
        coords={
            "lat": expected.lat,
            "lon": expected.lon,
            "time": xr.DataArray(
                data=np.array(
                    [
                        cftime.DatetimeGregorian(2000, 1, 1),
                        cftime.DatetimeGregorian(2000, 2, 1),
                        cftime.DatetimeGregorian(2000, 4, 1),
                        cftime.DatetimeGregorian(2001, 2, 1),
                    ],
                ),
                dims=["time"],
                attrs={
                    "axis": "T",
                    "long_name": "time",
                    "standard_name": "time",
                    "bounds": "time_bnds",
                },
            ),
        },
        dims=["time", "lat", "lon"],
        attrs={
            "test_attr": "test",
            "operation": "temporal_avg",
            "mode": "group_average",
            "freq": "month",
            "weighted": "True",
        },
    )

    # ``assert_identical`` also compares names and attributes.
    xr.testing.assert_identical(result, expected)

def test_weighted_daily_averages(self):
ds = self.ds.copy()

22 changes: 21 additions & 1 deletion tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pytest
import xarray as xr

from xcdat.utils import compare_datasets, str_to_bool
from xcdat.utils import _validate_min_weight, compare_datasets, str_to_bool


class TestCompareDatasets:
@@ -103,3 +103,23 @@ def test_raises_error_if_str_is_not_a_python_bool(self):

with pytest.raises(ValueError):
str_to_bool("1")


class TestValidateMinWeight:
    """Unit tests for ``xcdat.utils._validate_min_weight``."""

    def test_pass_None_returns_0(self):
        # ``None`` means "no threshold" and is normalized to 0.
        result = _validate_min_weight(None)

        assert result == 0

    def test_returns_error_if_less_than_0(self):
        # Values below the valid [0, 1] range are rejected.
        with pytest.raises(ValueError):
            _validate_min_weight(-1)

    def test_returns_error_if_greater_than_1(self):
        # Values above the valid [0, 1] range are rejected.
        with pytest.raises(ValueError):
            _validate_min_weight(1.1)

    def test_returns_valid_min_weight(self):
        # In-range values are returned unchanged (1 == 100% coverage).
        result = _validate_min_weight(1)

        assert result == 1
2 changes: 2 additions & 0 deletions xcdat/spatial.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
"""Module containing geospatial averaging functions."""
from __future__ import annotations

from functools import reduce
from typing import (
Callable,
92 changes: 62 additions & 30 deletions xcdat/temporal.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
"""Module containing temporal functions."""
from __future__ import annotations

from datetime import datetime
from itertools import chain
from typing import Dict, List, Literal, Optional, Tuple, TypedDict, Union, get_args
@@ -17,6 +19,7 @@
from xcdat._logger import _setup_custom_logger
from xcdat.axis import get_dim_coords
from xcdat.dataset import _get_data_var
from xcdat.utils import _get_masked_weights, _validate_min_weight

logger = _setup_custom_logger(__name__)

@@ -154,7 +157,12 @@ class TemporalAccessor:
def __init__(self, dataset: xr.Dataset):
self._dataset: xr.Dataset = dataset

def average(self, data_var: str, weighted: bool = True, keep_weights: bool = False):
def average(
self,
data_var: str,
weighted: bool = True,
keep_weights: bool = False,
):
"""
Returns a Dataset with the average of a data variable and the time
dimension removed.
@@ -230,7 +238,11 @@ def average(self, data_var: str, weighted: bool = True, keep_weights: bool = Fal
freq = _infer_freq(self._dataset[self.dim])

return self._averager(
data_var, "average", freq, weighted=weighted, keep_weights=keep_weights
data_var,
"average",
freq,
weighted=weighted,
keep_weights=keep_weights,
)

def group_average(
@@ -240,6 +252,7 @@ def group_average(
weighted: bool = True,
keep_weights: bool = False,
season_config: SeasonConfigInput = DEFAULT_SEASON_CONFIG,
min_weight: float | None = None,
):
"""Returns a Dataset with average of a data variable by time group.
@@ -318,6 +331,10 @@ def group_average(
>>> ["Jul", "Aug", "Sep"], # "JulAugSep"
>>> ["Oct", "Nov", "Dec"], # "OctNovDec"
>>> ]
min_weight : float | None, optional
Fraction of data coverage (i.e., weight) needed to return a
temporal average value. Value must range from 0 to 1, by default
None (equivalent to ``min_weight=0.0``).
Returns
-------
@@ -396,6 +413,7 @@ def group_average(
weighted=weighted,
keep_weights=keep_weights,
season_config=season_config,
min_weight=min_weight,
)

def climatology(
@@ -798,10 +816,13 @@ def _averager(
keep_weights: bool = False,
reference_period: Optional[Tuple[str, str]] = None,
season_config: SeasonConfigInput = DEFAULT_SEASON_CONFIG,
min_weight: float | None = None,
) -> xr.Dataset:
"""Averages a data variable based on the averaging mode and frequency."""
ds = self._dataset.copy()
self._set_arg_attrs(mode, freq, weighted, reference_period, season_config)
self._set_arg_attrs(
mode, freq, weighted, reference_period, season_config, min_weight
)

# Preprocess the dataset based on method argument values.
ds = self._preprocess_dataset(ds)
@@ -815,7 +836,7 @@ def _averager(
# it becomes obsolete after the data variable is averaged. When the
# averaged data variable is added to the dataset, the new time dimension
# and its associated coordinates are also added.
ds = ds.drop_dims(self.dim) # type: ignore
ds = ds.drop_dims(self.dim)
ds[dv_avg.name] = dv_avg

if keep_weights:
@@ -847,7 +868,7 @@ def _set_data_var_attrs(self, data_var: str):
dv = _get_data_var(self._dataset, data_var)

self.data_var = data_var
self.dim = get_dim_coords(dv, "T").name
self.dim = str(get_dim_coords(dv, "T").name)

if not _contains_datetime_like_objects(dv[self.dim]):
first_time_coord = dv[self.dim].values[0]
@@ -882,6 +903,7 @@ def _set_arg_attrs(
weighted: bool,
reference_period: Optional[Tuple[str, str]] = None,
season_config: SeasonConfigInput = DEFAULT_SEASON_CONFIG,
min_weight: float | None = None,
):
"""Validates method arguments and sets them as object attributes.
@@ -897,6 +919,10 @@ def _set_arg_attrs(
A dictionary for "season" frequency configurations. If configs for
predefined seasons are passed, configs for custom seasons are
ignored and vice versa, by default DEFAULT_SEASON_CONFIG.
min_weight : float | None, optional
Fraction of data coverage (i.e., weight) needed to return a
temporal average value. Value must range from 0 to 1, by default
None (equivalent to ``min_weight=0.0``).
Raises
------
@@ -924,6 +950,7 @@ def _set_arg_attrs(
self._mode = mode
self._freq = freq
self._weighted = weighted
self._min_weight = _validate_min_weight(min_weight)

self._reference_period = None
if reference_period is not None:
@@ -1115,9 +1142,7 @@ def _drop_leap_days(self, ds: xr.Dataset):
-------
xr.Dataset
"""
ds = ds.sel( # type: ignore
**{self.dim: ~((ds.time.dt.month == 2) & (ds.time.dt.day == 29))}
)
ds = ds.sel(**{self.dim: ~((ds.time.dt.month == 2) & (ds.time.dt.day == 29))})
return ds

def _average(self, ds: xr.Dataset, data_var: str) -> xr.DataArray:
@@ -1142,9 +1167,9 @@ def _average(self, ds: xr.Dataset, data_var: str) -> xr.DataArray:
time_bounds = ds.bounds.get_bounds("T", var_key=data_var)
self._weights = self._get_weights(time_bounds)

dv = dv.weighted(self._weights).mean(dim=self.dim) # type: ignore
dv = dv.weighted(self._weights).mean(dim=self.dim)
else:
dv = dv.mean(dim=self.dim) # type: ignore
dv = dv.mean(dim=self.dim)

dv = self._add_operation_attrs(dv)

@@ -1176,37 +1201,44 @@ def _group_average(self, ds: xr.Dataset, data_var: str) -> xr.DataArray:
time_bounds = ds.bounds.get_bounds("T", var_key=data_var)
self._weights = self._get_weights(time_bounds)

# Weight the data variable.
dv *= self._weights

# Ensure missing data (`np.nan`) receives no weight (zero). To
# achieve this, first broadcast the one-dimensional (temporal
# dimension) shape of the `weights` DataArray to the
# multi-dimensional shape of its corresponding data variable.
weights, _ = xr.broadcast(self._weights, dv)
weights = xr.where(dv.copy().isnull(), 0.0, weights)

# Perform weighted average using the formula
# WA = sum(data*weights) / sum(weights). The denominator must be
# included to take into account zero weight for missing data.
with xr.set_options(keep_attrs=True):
dv = self._group_data(dv).sum() / self._group_data(weights).sum()
dv_weighted = dv * self._weights

# Perform weighted average using the formula
# WA = sum(data*weights) / sum(masked weights).
# The denominator must be included to take into account zero
# weight for missing data.
dv_group_sum = self._group_data(dv_weighted).sum()
weights_masked = _get_masked_weights(dv_weighted, self._weights)
weights_masked_group_sum = self._group_data(weights_masked).sum()

dv_avg = dv_group_sum / weights_masked_group_sum

# Mask the data variable values with weights below the minimum
# weight threshold (if specified).
if self._min_weight > 0.0:
dv_avg = xr.where(
weights_masked_group_sum >= self._min_weight,
dv_avg,
np.nan,
keep_attrs=True,
)

# Restore the data variable's name.
dv.name = data_var
dv_avg.name = data_var
else:
dv = self._group_data(dv).mean()
dv_avg = self._group_data(dv).mean()

# After grouping and aggregating, the grouped time dimension's
# attributes are removed. Xarray's `keep_attrs=True` option only keeps
# attributes for data variables and not their coordinates, so the
# coordinate attributes have to be restored manually.
dv[self.dim].attrs = self._labeled_time.attrs
dv[self.dim].encoding = self._labeled_time.encoding
dv_avg[self.dim].attrs = self._labeled_time.attrs
dv_avg[self.dim].encoding = self._labeled_time.encoding

dv = self._add_operation_attrs(dv)
dv_avg = self._add_operation_attrs(dv_avg)

return dv
return dv_avg

def _get_weights(self, time_bounds: xr.DataArray) -> xr.DataArray:
"""Calculates weights for a data variable using time bounds.
60 changes: 60 additions & 0 deletions xcdat/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import importlib
import json
from typing import Dict, List, Optional, Union
@@ -132,3 +134,61 @@ def _if_multidim_dask_array_then_load(
return obj.load()

return None


def _get_masked_weights(dv: xr.DataArray, weights: xr.DataArray) -> xr.DataArray:
    """Get weights with missing data (`np.nan`) receiving no weight (zero).

    Parameters
    ----------
    dv : xr.DataArray
        The data variable.
    weights : xr.DataArray
        A DataArray containing either the regional or temporal weights used
        for weighted averaging. ``weights`` must include the same axis
        dimensions and dimensional sizes as the data variable.

    Returns
    -------
    xr.DataArray
        The masked weights (zero wherever ``dv`` is null, ``weights``
        elsewhere; broadcast against ``dv`` by ``xr.where``).
    """
    # ``isnull()`` allocates a new boolean DataArray, so the defensive
    # ``dv.copy()`` the original used before building the mask is
    # unnecessary work and has been removed.
    masked_weights = xr.where(dv.isnull(), 0.0, weights)

    return masked_weights


def _validate_min_weight(min_weight: float | None) -> float:
"""Validate the ``min_weight`` value.
Parameters
----------
min_weight : float | None
Fraction of data coverage (i..e, weight) needed to return a
spatial average value. Value must range from 0 to 1.
Returns
-------
float
The required weight percentage.
Raises
------
ValueError
If the `min_weight` argument is less than 0.
ValueError
If the `min_weight` argument is greater than 1.
"""
if min_weight is None:
return 0.0
elif min_weight < 0.0:
raise ValueError(
"min_weight argument is less than 0. " "min_weight must be between 0 and 1."
)
elif min_weight > 1.0:
raise ValueError(
"min_weight argument is greater than 1. "
"min_weight must be between 0 and 1."
)

return min_weight