Skip to content

Commit

Permalink
Fix coord rounding errors in model.predict return objects by circumventing `DataProcessor` norm-unnorm operation on coords
Browse files Browse the repository at this point in the history
  • Loading branch information
tom-andersson committed Jul 17, 2023
1 parent bbe9968 commit 74cc5b8
Showing 1 changed file with 99 additions and 56 deletions.
155 changes: 99 additions & 56 deletions deepsensor/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ def predict(
X_t: Union[
xr.Dataset, xr.DataArray, pd.DataFrame, pd.Series, pd.Index, np.ndarray
],
X_t_normalised: bool = False,
X_t_is_normalised: bool = False,
resolution_factor=1,
n_samples=0,
ar_sample=False,
Expand All @@ -210,24 +210,25 @@ def predict(
TODO:
- Test with multiple targets model
:param tasks: List of tasks containing context data.
:param X_t: Target locations to predict at. Can be an xarray object containing
on-grid locations or a pandas object containing off-grid locations.
:param X_t_normalised: Whether the `X_t` coords are normalised.
If False, will normalise the coords before passing to model. Default False.
:param resolution_factor: Optional factor to increase the resolution of the
target grid by. E.g. 2 will double the target resolution, 0.5 will halve it.
Applies to on-grid predictions only. Default 1.
:param n_samples: Number of joint samples to draw from the model.
If 0, will not draw samples. Default 0.
:param ar_sample: Whether to use autoregressive sampling. Default False.
:param unnormalise: Whether to unnormalise the predictions. Only works if
`self` has a `data_processor` and `task_loader` attribute. Default True.
:param seed: Random seed for deterministic sampling. Default 0.
:param append_indexes: Dictionary of index metadata to append to pandas indexes
in the off-grid case. Default None.
:param progress_bar: Whether to display a progress bar over tasks. Default 0.
:param verbose: Whether to print time taken for prediction. Default False.
Args:
tasks: List of tasks containing context data.
X_t: Target locations to predict at. Can be an xarray object containing
on-grid locations, or a pandas object or numpy array containing off-grid locations.
X_t_is_normalised: Whether the `X_t` coords are normalised.
If False, will normalise the coords before passing to model. Default False.
resolution_factor: Optional factor to increase the resolution of the
target grid by. E.g. 2 will double the target resolution, 0.5 will halve it.
Applies to on-grid predictions only. Default 1.
n_samples: Number of joint samples to draw from the model.
If 0, will not draw samples. Default 0.
ar_sample: Whether to use autoregressive sampling. Default False.
unnormalise: Whether to unnormalise the predictions. Only works if
`self` has a `data_processor` and `task_loader` attribute. Default True.
seed: Random seed for deterministic sampling. Default 0.
append_indexes: Dictionary of index metadata to append to pandas indexes
in the off-grid case. Default None.
progress_bar: Whether to display a progress bar over tasks. Default 0.
verbose: Whether to print time taken for prediction. Default False.
Returns:
- If X_t is a pandas object, returns pandas objects containing off-grid predictions.
Expand All @@ -242,12 +243,25 @@ def predict(
raise ValueError(
"resolution_factor can only be used with on-grid predictions."
)
if ar_subsample_factor != 1:
raise ValueError(
"ar_subsample_factor can only be used with on-grid predictions."
)
if not isinstance(X_t, (pd.DataFrame, pd.Series, pd.Index, np.ndarray)):
if append_indexes is not None:
raise ValueError(
"append_indexes can only be used with off-grid predictions."
)

if isinstance(X_t, (xr.DataArray, xr.Dataset)):
mode = "on-grid"
elif isinstance(X_t, (pd.DataFrame, pd.Series, pd.Index, np.ndarray)):
mode = "off-grid"
else:
raise ValueError(
f"X_t must be and xarray, pandas or numpy object. Got {type(X_t)}."
)

if type(tasks) is Task:
tasks = [tasks]

Expand All @@ -262,59 +276,78 @@ def predict(
var_ID for set in self.task_loader.target_var_IDs for var_ID in set
]

# Pre-process X_t if necessary
if isinstance(X_t, pd.Index):
X_t = pd.DataFrame(index=X_t)
elif isinstance(X_t, np.ndarray):
# Convert to empty dataframe with normalised or unnormalised coord names
if X_t_normalised:
if X_t_is_normalised:
index_names = ["x1", "x2"]
else:
index_names = self.data_processor.raw_spatial_coord_names
X_t = pd.DataFrame(X_t.T, columns=index_names)
X_t = X_t.set_index(index_names)
if mode == "off-grid" and append_indexes is not None:
# Check append_indexes are all same length as X_t
if append_indexes is not None:
for idx, vals in append_indexes.items():
if len(vals) != len(X_t):
raise ValueError(
f"append_indexes[{idx}] must be same length as X_t, got {len(vals)} and {len(X_t)} respectively."
)
X_t = X_t.reset_index()
X_t = pd.concat([X_t, pd.DataFrame(append_indexes)], axis=1)
X_t = X_t.set_index(list(X_t.columns))

if not X_t_normalised:
X_t = self.data_processor.map_coords(X_t) # Normalise
if X_t_is_normalised:
X_t_normalised = X_t

if isinstance(X_t, (xr.DataArray, xr.Dataset)):
mode = "on-grid"
elif isinstance(X_t, (pd.DataFrame, pd.Series, pd.Index)):
mode = "off-grid"
if append_indexes is not None:
# Check append_indexes are all same length as X_t
if append_indexes is not None:
for idx, vals in append_indexes.items():
if len(vals) != len(X_t):
raise ValueError(
f"append_indexes[{idx}] must be same length as X_t, got {len(vals)} and {len(X_t)} respectively."
)
X_t = X_t.reset_index()
X_t = pd.concat([X_t, pd.DataFrame(append_indexes)], axis=1)
X_t = X_t.set_index(list(X_t.columns))
# Unnormalise coords for use in the xarray/pandas objects that store predictions
X_t = self.data_processor.map_coords(X_t, unnorm=True)
else:
raise ValueError(
f"X_t must be an xarray object or a pandas object, not {type(X_t)}"
)
# Normalise coords to use for model
X_t_normalised = self.data_processor.map_coords(X_t)

if mode == "on-grid":
X_t_arr = (X_t_normalised["x1"].values, X_t_normalised["x2"].values)
elif mode == "off-grid":
X_t_arr = X_t_normalised.reset_index()[["x1", "x2"]].values.T

if not unnormalise:
X_t = X_t_normalised
coord_names = {"x1": "x1", "x2": "x2"}
elif unnormalise:
coord_names = {
"x1": self.data_processor.raw_spatial_coord_names[0],
"x2": self.data_processor.raw_spatial_coord_names[1],
}

# Create empty xarray/pandas objects to store predictions
if mode == "on-grid":
mean = create_empty_spatiotemporal_xarray(
X_t, dates, resolution_factor, data_vars=target_var_IDs
X_t,
dates,
resolution_factor,
data_vars=target_var_IDs,
coord_names=coord_names,
).to_array(dim="data_var")
std = create_empty_spatiotemporal_xarray(
X_t, dates, resolution_factor, data_vars=target_var_IDs
X_t,
dates,
resolution_factor,
data_vars=target_var_IDs,
coord_names=coord_names,
).to_array(dim="data_var")
if n_samples >= 1:
samples = create_empty_spatiotemporal_xarray(
X_t,
dates,
resolution_factor,
data_vars=target_var_IDs,
coord_names=coord_names,
prepend_dims=["sample"],
prepend_coords={"sample": np.arange(n_samples)},
).to_array(dim="data_var")

X_t_arr = (mean["x1"].values, mean["x2"].values)

elif mode == "off-grid":
# Repeat target locs for each date to create multiindex
idxs = [(date, *idxs) for date in dates for idxs in X_t.index]
Expand All @@ -333,7 +366,18 @@ def predict(
)
samples = pd.DataFrame(index=index_samples, columns=target_var_IDs)

X_t_arr = X_t.reset_index()[["x1", "x2"]].values.T
def unnormalise_pred_array(arr, **kwargs):
    """Unnormalise a stacked prediction array, one target variable at a time.

    Args:
        arr: Array whose first axis indexes the flattened target variables of
            `self.task_loader.target_var_IDs`. Its slices are overwritten.
        **kwargs: Extra keyword arguments forwarded unchanged to
            `self.data_processor.map_array`.

    Returns:
        The same `arr` object, with each leading-axis slice replaced by the
        result of `map_array(..., method="mean_std", unnorm=True)`.
    """
    # Flatten the per-target-set variable IDs into a single ordered list
    flat_var_IDs = []
    for var_ID_group in self.task_loader.target_var_IDs:
        flat_var_IDs.extend(var_ID_group)

    # One leading-axis slice is expected per target variable
    assert arr.shape[0] == len(flat_var_IDs)

    for row, var_ID in enumerate(flat_var_IDs):
        unnormed = self.data_processor.map_array(
            arr[row], var_ID, method="mean_std", unnorm=True, **kwargs
        )
        arr[row] = unnormed
    return arr

# Don't change tasks by reference when overriding target locations
tasks = copy.deepcopy(tasks)
Expand Down Expand Up @@ -385,6 +429,15 @@ def predict(
if n_samples >= 1:
samples_arr = np.concatenate(samples_arr, axis=0)

if unnormalise:
mean_arr = unnormalise_pred_array(mean_arr)
std_arr = unnormalise_pred_array(std_arr, add_offset=False)
if n_samples >= 1:
for sample_i in range(n_samples):
samples_arr[sample_i] = unnormalise_pred_array(
samples_arr[sample_i]
)

if mode == "on-grid":
mean.loc[:, task["time"], :, :] = mean_arr
std.loc[:, task["time"], :, :] = std_arr
Expand All @@ -407,16 +460,6 @@ def predict(
if n_samples >= 1:
samples = samples.to_dataset(dim="data_var")

if (
self.task_loader is not None
and self.data_processor is not None
and unnormalise == True
):
mean = self.data_processor.unnormalise(mean)
std = self.data_processor.unnormalise(std, add_offset=False)
if n_samples >= 1:
samples = self.data_processor.unnormalise(samples)

if verbose:
dur = time.time() - tic
print(f"Done in {np.floor(dur / 60)}m:{dur % 60:.0f}s.\n")
Expand Down

0 comments on commit 74cc5b8

Please sign in to comment.