Skip to content

Commit

Permalink
Fix context NaN handling and add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
tom-andersson committed Jul 12, 2023
1 parent 09c0520 commit 094c510
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 4 deletions.
6 changes: 4 additions & 2 deletions deepsensor/model/convnp.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,11 +172,13 @@ def array_modify_fn(arr):

arr = arr.astype(np.float32) # Cast to float32

# Find NaNs and keep size-1 variable dim
mask = np.any(np.isnan(arr), axis=1, keepdims=True)
# Find NaNs
mask = np.isnan(arr)
if np.any(mask):
# Set NaNs to zero - necessary for `neuralprocesses` (can't have NaNs)
arr[mask] = 0.0
# Mask array (True for observed, False for missing) - keep size 1 variable dim
mask = ~np.any(mask, axis=1, keepdims=True)

# Convert to tensor object based on deep learning backend
arr = backend.convert_to_tensor(arr)
Expand Down
48 changes: 46 additions & 2 deletions tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def __init__(self, *args, **kwargs):
self.df = _gen_data_pandas()

self.dp = DataProcessor()
_ = self.dp([self.da, self.df]) # Compute normalization parameters
_ = self.dp([self.da, self.df]) # Compute normalisation parameters

def _gen_task_loader_call_args(self, n_context, n_target):
"""Generate arguments for TaskLoader.__call__
Expand Down Expand Up @@ -194,6 +194,48 @@ def test_prediction_shapes_lowlevel(self, n_target_sets):
x = B.to_numpy(model.loss_fn(task))
assert x.size == 1 and x.shape == ()

@parameterized.expand(range(1, 4))
def test_nans_offgrid_context(self, ndim):
    """Check that `ConvNP` runs a forward pass when offgrid context has NaNs.

    For each NaN pattern, a task is generated with `context_sampling=10` and
    `target_sampling=10`, NaNs are written into the first context set, and a
    freshly constructed model is called on the task. The test passes if no
    exception is raised.
    """
    tl = TaskLoader(
        context=_gen_data_xr(data_vars=range(ndim)),
        target=self.da,
    )

    nan_patterns = (
        np.s_[:, 0],  # "All NaNs" case: whole first axis at trailing index 0
        np.s_[0, 0],  # "One NaN" case: a single entry
    )
    for nan_index in nan_patterns:
        # Generate a new task per case so the NaN patterns do not accumulate
        task = tl("2020-01-01", context_sampling=10, target_sampling=10)
        context_values = task["Y_c"][0]
        context_values[nan_index] = np.nan
        model = ConvNP(self.dp, tl, unet_channels=(5, 5, 5), verbose=False)
        _ = model(task)

@parameterized.expand(range(1, 4))
def test_nans_gridded_context(self, ndim):
    """Check that `ConvNP` runs a forward pass when gridded context has NaNs.

    For each NaN pattern, a task is generated with `context_sampling="all"`
    and `target_sampling=10`, NaNs are written into the first context set, and
    a freshly constructed model is called on the task. The test passes if no
    exception is raised.
    """
    tl = TaskLoader(
        context=_gen_data_xr(data_vars=range(ndim)),
        target=self.da,
    )

    nan_patterns = (
        np.s_[:, 0, 0],  # "All NaNs" case: whole first axis at index (0, 0)
        np.s_[0, 0, 0],  # "One NaN" case: a single entry
    )
    for nan_index in nan_patterns:
        # Generate a new task per case so the NaN patterns do not accumulate
        task = tl("2020-01-01", context_sampling="all", target_sampling=10)
        context_values = task["Y_c"][0]
        context_values[nan_index] = np.nan
        model = ConvNP(self.dp, tl, unet_channels=(5, 5, 5), verbose=False)
        _ = model(task)

@parameterized.expand(range(1, 4))
def test_prediction_shapes_highlevel(self, target_dim):
"""Test high-level `.predict` interface over a range of number of target sets
Expand Down Expand Up @@ -226,7 +268,9 @@ def test_prediction_shapes_highlevel(self, target_dim):
tasks,
X_t=self.da,
n_samples=n_samples,
unnormalise=False if target_dim > 1 else True,
unnormalise=True
if target_dim == 1
else False, # TODO fix unnormalising for multiple equally named targets
)
assert [isinstance(ds, xr.Dataset) for ds in [mean_ds, std_ds, samples_ds]]
assert_shape(
Expand Down

0 comments on commit 094c510

Please sign in to comment.