From 74cc5b863c266229ddc9c58589052beb20bf7872 Mon Sep 17 00:00:00 2001 From: Tom Andersson Date: Mon, 17 Jul 2023 16:45:58 +0100 Subject: [PATCH] Fix coord rounding errors in `model.predict` return objects by circumventing `DataProcessor` norm-unnorm operation on coords --- deepsensor/model/model.py | 155 ++++++++++++++++++++++++-------------- 1 file changed, 99 insertions(+), 56 deletions(-) diff --git a/deepsensor/model/model.py b/deepsensor/model/model.py index 7425ed91..71ace7dc 100644 --- a/deepsensor/model/model.py +++ b/deepsensor/model/model.py @@ -194,7 +194,7 @@ def predict( X_t: Union[ xr.Dataset, xr.DataArray, pd.DataFrame, pd.Series, pd.Index, np.ndarray ], - X_t_normalised: bool = False, + X_t_is_normalised: bool = False, resolution_factor=1, n_samples=0, ar_sample=False, @@ -210,24 +210,25 @@ def predict( TODO: - Test with multiple targets model - :param tasks: List of tasks containing context data. - :param X_t: Target locations to predict at. Can be an xarray object containing - on-grid locations or a pandas object containing off-grid locations. - :param X_t_normalised: Whether the `X_t` coords are normalised. - If False, will normalise the coords before passing to model. Default False. - :param resolution_factor: Optional factor to increase the resolution of the - target grid by. E.g. 2 will double the target resolution, 0.5 will halve it. - Applies to on-grid predictions only. Default 1. - :param n_samples: Number of joint samples to draw from the model. - If 0, will not draw samples. Default 0. - :param ar_sample: Whether to use autoregressive sampling. Default False. - :param unnormalise: Whether to unnormalise the predictions. Only works if - `self` has a `data_processor` and `task_loader` attribute. Default True. - :param seed: Random seed for deterministic sampling. Default 0. - :param append_indexes: Dictionary of index metadata to append to pandas indexes - in the off-grid case. Default None. - :param progress_bar: Whether to display a progress bar over tasks. Default 0. - :param verbose: Whether to print time taken for prediction. Default False. + Args: + tasks: List of tasks containing context data. + X_t: Target locations to predict at. Can be an xarray object containing + on-grid locations or a pandas object containing off-grid locations. + X_t_is_normalised: Whether the `X_t` coords are normalised. + If False, will normalise the coords before passing to model. Default False. + resolution_factor: Optional factor to increase the resolution of the + target grid by. E.g. 2 will double the target resolution, 0.5 will halve it. + Applies to on-grid predictions only. Default 1. + n_samples: Number of joint samples to draw from the model. + If 0, will not draw samples. Default 0. + ar_sample: Whether to use autoregressive sampling. Default False. + unnormalise: Whether to unnormalise the predictions. Only works if + `self` has a `data_processor` and `task_loader` attribute. Default True. + seed: Random seed for deterministic sampling. Default 0. + append_indexes: Dictionary of index metadata to append to pandas indexes + in the off-grid case. Default None. + progress_bar: Whether to display a progress bar over tasks. Default 0. + verbose: Whether to print time taken for prediction. Default False. Returns: - If X_t is a pandas object, returns pandas objects containing off-grid predictions. @@ -242,12 +243,25 @@ def predict( raise ValueError( "resolution_factor can only be used with on-grid predictions." ) + if ar_subsample_factor != 1: + raise ValueError( + "ar_subsample_factor can only be used with on-grid predictions." + ) if not isinstance(X_t, (pd.DataFrame, pd.Series, pd.Index, np.ndarray)): if append_indexes is not None: raise ValueError( "append_indexes can only be used with off-grid predictions." ) + if isinstance(X_t, (xr.DataArray, xr.Dataset)): + mode = "on-grid" + elif isinstance(X_t, (pd.DataFrame, pd.Series, pd.Index, np.ndarray)): + mode = "off-grid" + else: + raise ValueError( + f"X_t must be and xarray, pandas or numpy object. Got {type(X_t)}." + ) + if type(tasks) is Task: tasks = [tasks] @@ -262,46 +276,67 @@ def predict( var_ID for set in self.task_loader.target_var_IDs for var_ID in set ] + # Pre-process X_t if necessary if isinstance(X_t, pd.Index): X_t = pd.DataFrame(index=X_t) elif isinstance(X_t, np.ndarray): # Convert to empty dataframe with normalised or unnormalised coord names - if X_t_normalised: + if X_t_is_normalised: index_names = ["x1", "x2"] else: index_names = self.data_processor.raw_spatial_coord_names X_t = pd.DataFrame(X_t.T, columns=index_names) X_t = X_t.set_index(index_names) + if mode == "off-grid" and append_indexes is not None: + # Check append_indexes are all same length as X_t + if append_indexes is not None: + for idx, vals in append_indexes.items(): + if len(vals) != len(X_t): + raise ValueError( + f"append_indexes[{idx}] must be same length as X_t, got {len(vals)} and {len(X_t)} respectively." + ) + X_t = X_t.reset_index() + X_t = pd.concat([X_t, pd.DataFrame(append_indexes)], axis=1) + X_t = X_t.set_index(list(X_t.columns)) - if not X_t_normalised: - X_t = self.data_processor.map_coords(X_t) # Normalise + if X_t_is_normalised: + X_t_normalised = X_t - if isinstance(X_t, (xr.DataArray, xr.Dataset)): - mode = "on-grid" - elif isinstance(X_t, (pd.DataFrame, pd.Series, pd.Index)): - mode = "off-grid" - if append_indexes is not None: - # Check append_indexes are all same length as X_t - if append_indexes is not None: - for idx, vals in append_indexes.items(): - if len(vals) != len(X_t): - raise ValueError( - f"append_indexes[{idx}] must be same length as X_t, got {len(vals)} and {len(X_t)} respectively." - ) - X_t = X_t.reset_index() - X_t = pd.concat([X_t, pd.DataFrame(append_indexes)], axis=1) - X_t = X_t.set_index(list(X_t.columns)) + # Unnormalise coords to use for xarray/pandas objects for storing predictions + X_t = self.data_processor.map_coords(X_t, unnorm=True) else: - raise ValueError( - f"X_t must be an xarray object or a pandas object, not {type(X_t)}" - ) + # Normalise coords to use for model + X_t_normalised = self.data_processor.map_coords(X_t) + if mode == "on-grid": + X_t_arr = (X_t_normalised["x1"].values, X_t_normalised["x2"].values) + elif mode == "off-grid": + X_t_arr = X_t_normalised.reset_index()[["x1", "x2"]].values.T + + if not unnormalise: + X_t = X_t_normalised + coord_names = {"x1": "x1", "x2": "x2"} + elif unnormalise: + coord_names = { + "x1": self.data_processor.raw_spatial_coord_names[0], + "x2": self.data_processor.raw_spatial_coord_names[1], + } + + # Create empty xarray/pandas objects to store predictions if mode == "on-grid": mean = create_empty_spatiotemporal_xarray( - X_t, dates, resolution_factor, data_vars=target_var_IDs + X_t, + dates, + resolution_factor, + data_vars=target_var_IDs, + coord_names=coord_names, ).to_array(dim="data_var") std = create_empty_spatiotemporal_xarray( - X_t, dates, resolution_factor, data_vars=target_var_IDs + X_t, + dates, + resolution_factor, + data_vars=target_var_IDs, + coord_names=coord_names, ).to_array(dim="data_var") if n_samples >= 1: samples = create_empty_spatiotemporal_xarray( @@ -309,12 +344,10 @@ def predict( dates, resolution_factor, data_vars=target_var_IDs, + coord_names=coord_names, prepend_dims=["sample"], prepend_coords={"sample": np.arange(n_samples)}, ).to_array(dim="data_var") - - X_t_arr = (mean["x1"].values, mean["x2"].values) - elif mode == "off-grid": # Repeat target locs for each date to create multiindex idxs = [(date, *idxs) for date in dates for idxs in X_t.index] @@ -333,7 +366,18 @@ def predict( ) samples = pd.DataFrame(index=index_samples, columns=target_var_IDs) - X_t_arr = X_t.reset_index()[["x1", "x2"]].values.T + def unnormalise_pred_array(arr, **kwargs): + var_IDs_flattened = [ + var_ID + for var_IDs in self.task_loader.target_var_IDs + for var_ID in var_IDs + ] + assert arr.shape[0] == len(var_IDs_flattened) + for i, var_ID in enumerate(var_IDs_flattened): + arr[i] = self.data_processor.map_array( + arr[i], var_ID, method="mean_std", unnorm=True, **kwargs + ) + return arr # Don't change tasks by reference when overriding target locations tasks = copy.deepcopy(tasks) @@ -385,6 +429,15 @@ def predict( if n_samples >= 1: samples_arr = np.concatenate(samples_arr, axis=0) + if unnormalise: + mean_arr = unnormalise_pred_array(mean_arr) + std_arr = unnormalise_pred_array(std_arr, add_offset=False) + if n_samples >= 1: + for sample_i in range(n_samples): + samples_arr[sample_i] = unnormalise_pred_array( + samples_arr[sample_i] + ) + if mode == "on-grid": mean.loc[:, task["time"], :, :] = mean_arr std.loc[:, task["time"], :, :] = std_arr @@ -407,16 +460,6 @@ def predict( if n_samples >= 1: samples = samples.to_dataset(dim="data_var") - if ( - self.task_loader is not None - and self.data_processor is not None - and unnormalise == True - ): - mean = self.data_processor.unnormalise(mean) - std = self.data_processor.unnormalise(std, add_offset=False) - if n_samples >= 1: - samples = self.data_processor.unnormalise(samples) - if verbose: dur = time.time() - tic print(f"Done in {np.floor(dur / 60)}m:{dur % 60:.0f}s.\n")