
[Enhancement]: Consider how to test for correct weights generation in xCDAT #699

Closed

tomvothecoder opened this issue Sep 23, 2024 · 5 comments

Labels: type: enhancement (New enhancement request)

tomvothecoder (Collaborator) commented Sep 23, 2024

Is your feature request related to a problem?

In PR #689, in the TemporalAccessor._get_weights() method, I removed the validation that checks that the sum of weights for each time group adds up to 1 (related code).

This was done because:

  1. Assertion code should only be used in testing and debugging (link) -- TL;DR: assertions can be disabled by the user with the -O flag, so relying on them for runtime validation is not good practice. I copied this validation code from an Xarray notebook (link) without realizing the implications.
  2. I added unit tests that cover the implementation logic using the same validation code (here).
  3. np.testing.assert_allclose() degrades performance -- I expect the degradation to scale with the size of the data; it may load data into memory via .values.

Describe the solution you'd like

What is an assertion and when to use it

assert is a common idiom in many languages. The purpose is to check conditions that are expected to always be true. A failed assertion indicates a coding bug. Therefore, the method returns nothing normally, and throws an exception if the condition fails.
...
Assertions are supposed to pass 100% of the time. An exception is only thrown when the condition unexpectedly fails.

-- https://stackoverflow.com/questions/42837054/are-assert-methods-acceptable-in-production
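
To illustrate the difference, here is a minimal sketch (not xCDAT code; the function names and tolerance are hypothetical). Running Python with the -O flag strips assert statements entirely, while an explicit raise always executes:

def check_total_with_assert(total: float) -> None:
    # Removed entirely under `python -O`, so the check silently
    # disappears in optimized production runs.
    assert abs(total - 1.0) < 1e-9, "group weights do not sum to 1"


def check_total_with_raise(total: float) -> None:
    # Always executes, regardless of interpreter flags.
    if abs(total - 1.0) >= 1e-9:
        raise RuntimeError("group weights do not sum to 1")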

Option 1: Don't keep this assertion

I think the sum of weights for each group should always be 1.0 (100%) based on our implementation logic. A failed assertion would therefore indicate a coding error on our part rather than bad data (although we could raise an exception if it is indeed due to bad time bounds). As far as I can tell, our implementation logic is correct, and nobody has run into _get_weights() throwing an AssertionError from the validation code.

The _get_weights() method works by:

  1. Getting the time lengths by taking the difference between the upper and lower time bounds (time_lengths)
  2. Grouping the time lengths (grouped_time_lengths)
  3. Dividing the grouped time lengths by the sum of each group (grouped_time_lengths / grouped_time_lengths.sum())

xcdat/xcdat/temporal.py

Lines 1213 to 1262 in e9e73dd

def _get_weights(self, time_bounds: xr.DataArray) -> xr.DataArray:
    """Calculates weights for a data variable using time bounds.

    This method gets the length of time for each coordinate point by using
    the difference in the upper and lower time bounds. This approach ensures
    that the correct time lengths are calculated regardless of how time
    coordinates are recorded (e.g., monthly, daily, hourly) and the calendar
    type used.

    The time lengths are labeled and grouped, then each time length is
    divided by the total sum of the time lengths in its group to get its
    corresponding weight.

    Parameters
    ----------
    time_bounds : xr.DataArray
        The time bounds.

    Returns
    -------
    xr.DataArray
        The weights based on a specified frequency.

    Notes
    -----
    Refer to [4]_ for the supported CF convention calendar types.

    References
    ----------
    .. [4] https://cfconventions.org/cf-conventions/cf-conventions.html#calendar
    """
    with xr.set_options(keep_attrs=True):
        time_lengths: xr.DataArray = time_bounds[:, 1] - time_bounds[:, 0]

    # Must cast the dtype from "timedelta64[ns]" to "float64", specifically
    # when using Dask arrays. Otherwise, the numpy warning below is thrown:
    # `DeprecationWarning: The `dtype` and `signature` arguments to ufuncs
    # only select the general DType and not details such as the byte order
    # or time unit (with rare exceptions see release notes). To avoid this
    # warning please use the scalar types `np.float64`, or string notation.`
    if isinstance(time_lengths.data, Array):
        time_lengths = time_lengths.astype("timedelta64[ns]")

    time_lengths = time_lengths.astype(np.float64)

    grouped_time_lengths = self._group_data(time_lengths)
    weights: xr.DataArray = grouped_time_lengths / grouped_time_lengths.sum()
    weights.name = f"{self.dim}_wts"

    return weights
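
The normalization in the last step is what guarantees the per-group sums. Here is a standalone sketch of the same three steps using plain xarray (toy monthly data grouped by year; this mirrors the logic above but does not call xCDAT's _group_data()):

import numpy as np
import pandas as pd
import xarray as xr

# Toy monthly axis: 24 months with bounds at consecutive month starts.
edges = pd.date_range("2000-01-01", periods=25, freq="MS")
time_bounds = xr.DataArray(
    np.array([edges[:-1], edges[1:]]).T,
    dims=["time", "bnds"],
    coords={"time": edges[:-1]},
)

# Step 1: time lengths from the difference of the upper and lower bounds.
time_lengths = (time_bounds[:, 1] - time_bounds[:, 0]).astype(np.float64)

# Steps 2 and 3: group the lengths (here by year) and divide each length
# by its group's total.
grouped = time_lengths.groupby("time.year")
weights = grouped / grouped.sum()

# The weights within each group sum to 1.0 by construction.
print(weights.groupby("time.year").sum().values)  # [1. 1.]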

Option 2: Keep this assertion

Maybe the time bounds are incorrect for some reason (e.g., missing data) and produce incorrect weights for certain groups? I'm not sure here. I will try experimenting with bad time bounds.

If we want to keep this assertion, we could:

  1. Figure out how to safely use np.testing.assert_allclose() in production, since it raises an AssertionError that users may treat as a debugging-only signal
  2. Find a way to optimize the performance of np.testing.assert_allclose()
  3. Raise a RuntimeError if the numpy AssertionError is raised (see the sketch below)
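
A minimal sketch of item 3 (the helper name and error message are hypothetical, not xCDAT code):

import numpy as np


def _validate_group_weight_sums(group_sums: np.ndarray) -> None:
    # Convert numpy's AssertionError into a domain-specific RuntimeError
    # so a failed check always surfaces as a hard error in production.
    try:
        np.testing.assert_allclose(group_sums, np.ones_like(group_sums))
    except AssertionError as err:
        raise RuntimeError(
            "Sum of weights for one or more groups does not equal 1.0, "
            "which may indicate malformed time bounds."
        ) from err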

Describe alternatives you've considered

No response

Additional context

After this ticket is addressed, we can proceed with releasing v0.7.2.

tomvothecoder (Collaborator, Author):

@xCDAT/core-developers Any input would be appreciated here. Thanks.

pochedls (Collaborator):

I think you can go with option 1 (drop the assertion). Even though I raised a concern about removing this, after reviewing the code carefully, I don't think it is needed. Do you remember why it was put there in the first place?

I don't think I helped write this section of code, but I sometimes add these kinds of checks to make sure NaNs aren't messing things up. I don't think that should be an issue in this instance.

tomvothecoder (Collaborator, Author):

Do you remember why it was put there in the first place?

I based my initial implementation of _get_weights() on the Xarray notebook here, which includes that validation code. It was useful to have while implementing the logic in xCDAT, but I don't think it is necessary in production.

tomvothecoder (Collaborator, Author) commented Sep 25, 2024

I ran all of the grouping-based temporal averaging APIs with a dummy dataset that has bad time bounds. They all worked and did not throw an AssertionError.

Weights should always add up to 1.0 for each group based on how the _group_data() method works. This method groups time coordinates by their group label, and the time bounds are used to calculate the weight for each specific coordinate. If we sum the weights within each group, they add up to 1.0 by construction. If there is a bad time coordinate (e.g., one wildly off time coordinate), it becomes its own group with a summed weight of 1.0. That is a data quality issue the user should address beforehand.

# %%
import numpy as np
import pandas as pd
import xarray as xr
import xcdat as xc  # noqa: F401 -- importing xcdat registers the .temporal accessor

# Define a time range with valid dates
times = pd.date_range(start="2000-01-01", periods=30, freq="ME")


# %%
# Create time bounds (lower and upper values)
time_bnds = np.array([times[:-1], times[1:]]).T
time_bnds = np.concatenate(
    (
        time_bnds,
        np.array([[times[-1], pd.Timestamp("2010-12-31")]], dtype="datetime64[ns]"),
    )
)
# Introduce incorrect time bounds
time_bnds[4, 1] = pd.Timestamp("3000-01-01")
time_bnds[8, 0] = pd.Timestamp("3000-01-01")


# %%
# Create some dummy data
data = np.random.rand(len(times))

# Create a dataset with these time coordinates and time bounds
ds = xr.Dataset(
    {"data": (["time"], data)},
    coords={"time": times, "time_bnds": (["time", "bnds"], time_bnds)},
)


# Add CF attributes
ds.attrs = {
    "Conventions": "CF-1.8",
    "title": "Sample Dataset",
    "institution": "Your Institution",
    "source": "Synthetic data",
    "history": "Created for demonstration purposes",
    "references": "http://example.com",
}

ds["data"].attrs = {
    "long_name": "Random Data",
    "units": "1",
    "standard_name": "random_data",
}

ds["time"].attrs = {
    "long_name": "Time",
    "standard_name": "time",
    "axis": "T",
    "bounds": "time_bnds",
}

ds["time_bnds"].attrs = {
    "long_name": "Time Bounds",
    "standard_name": "time_bnds",
}


# %%
ds_day = ds.temporal.group_average("data", freq="day")
ds_month = ds.temporal.group_average("data", freq="month")
ds_season = ds.temporal.group_average("data", freq="season")
ds_year = ds.temporal.group_average("data", freq="year")

# %%
ds_climo_day = ds.temporal.climatology("data", freq="day")
ds_climo_month = ds.temporal.climatology("data", freq="month")
ds_climo_season = ds.temporal.climatology("data", freq="season")

# %%
ds_depart_day = ds.temporal.departures("data", freq="day")
ds_depart_month = ds.temporal.departures("data", freq="month")
ds_depart_season = ds.temporal.departures("data", freq="season")
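
As a quick sanity check on the above, the per-group weight sums can be recomputed outside of xCDAT (a sketch that mirrors the weighting logic with plain xarray, grouping by year; it does not call xCDAT internals):

# %%
# Sanity check (not xCDAT internals): recompute the weights directly from
# the bounds and confirm they sum to 1.0 per group, here grouping by year.
time_lengths = (ds.time_bnds[:, 1] - ds.time_bnds[:, 0]).astype(np.float64)

grouped = time_lengths.groupby("time.year")
weights = grouped / grouped.sum()

# Even with the intentionally bad bounds above, each group's weights still
# sum to 1.0 because the normalization is per group.
print(weights.groupby("time.year").sum().values)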

tomvothecoder (Collaborator, Author):

Closing this issue based on my findings.
