From 74b19c90fbda54639f21d3d7268251412f37d4db Mon Sep 17 00:00:00 2001 From: Mathias Hauser Date: Mon, 6 Nov 2023 15:25:00 +0100 Subject: [PATCH] example use make_realisations function (#325) --- examples/train_create_emus_automated.py | 56 ++++++++----------- mesmer/create_emulations/make_realisations.py | 15 ++++- .../test_draw_realisations_from_bundle.py | 2 +- 3 files changed, 36 insertions(+), 37 deletions(-) diff --git a/examples/train_create_emus_automated.py b/examples/train_create_emus_automated.py index 894be9d5..44ac8607 100644 --- a/examples/train_create_emus_automated.py +++ b/examples/train_create_emus_automated.py @@ -1,16 +1,15 @@ import warnings +import xarray as xr + # import MESMER tools from mesmer.calibrate_mesmer import train_gt, train_gv, train_lt, train_lv from mesmer.create_emulations import ( - create_emus_g, - create_emus_gv, - create_emus_l, create_emus_lt, create_emus_lv, gather_gt_data, + make_realisations, ) -from mesmer.create_emulations.utils import concatenate_hist_future from mesmer.io import load_cmipng, load_phi_gc, load_regs_ls_wgt_lon_lat from mesmer.utils import convert_dict_to_arr, extract_land, separate_hist_future @@ -71,7 +70,7 @@ def main(cfg): print(f"{esm}") print("=" * len(esm)) - print("Calibration") + print("\nCalibration") print("-----------") print("- Start with global trend module") @@ -139,42 +138,33 @@ def main(cfg): {}, targs_res_lv, esm, cfg, save_params=True, aux=aux, params_lv=params_lv ) - print("Emulation") + print("\nEmulation") print("---------") - # for this example we use the model's own smoothed gsat as predictor - gt_tas = concatenate_hist_future(gt_tas_s) - - scen = list(gt_tas.keys())[0] - - time_v = {} - time_v["all"] = time[esm][scen] - - print("- Create global variability emulations") - preds_gv = {"time": time_v} - emus_gv_tas = create_emus_gv(params_gv_tas, preds_gv, cfg, save_emus=True) - - print("- Merge the global trend and the global variability.") - - create_emus_g( - gt_tas, emus_gv_tas, params_gt_tas, params_gv_tas, cfg, save_emus=True + # create a land_fraction DataArray, so we can determine the grid coordinates + land_fractions = xr.DataArray( + ls["grid_l_m"], + dims=["lat", "lon"], + coords={"lat": lat["c"], "lon": lon["c"]}, ) - print("- Create local trend emulations") - - emus_lt = create_emus_lt( - params_lt, preds_lt, cfg, concat_h_f=True, save_emus=True + realisations = make_realisations( + # preds_lt=gt_tas, + preds_lt=preds_lt, + params_lt=params_lt, + params_lv=params_lv, + params_gv_T=params_gv_tas, + time=time[esm], + n_realisations=cfg.nr_emus_v, + seeds=cfg.seed, + land_fractions=land_fractions, ) - print("- Create local variability emulations") - - preds_lv = {"gvtas": emus_gv_tas} # predictors_list - emus_lv = create_emus_lv(params_lv, preds_lv, cfg, save_emus=True) + print("\nCreated emulations") + print("------------------") - # create and save full emulations - print("- Merge the local trends and the local variability.") - create_emus_l(emus_lt, emus_lv, params_lt, params_lv, cfg, save_emus=True) + print(realisations) if __name__ == "__main__": diff --git a/mesmer/create_emulations/make_realisations.py b/mesmer/create_emulations/make_realisations.py index 0403ddf0..9520b65d 100644 --- a/mesmer/create_emulations/make_realisations.py +++ b/mesmer/create_emulations/make_realisations.py @@ -130,7 +130,12 @@ def __init__(self, n_realisations, seeds): # TODO: add better checks for what happens if scenarios have different # time axis etc. a_scenario_key = [k for k in time.keys() if k != "hist"][0] - time_all = np.concatenate([time["hist"], time[a_scenario_key]]) + + if "hist" in time: + time_all = np.concatenate([time["hist"], time[a_scenario_key]]) + else: + assert a_scenario_key.startswith("h-") + time_all = time[a_scenario_key] esm_gv_T = params_gv_T["esm"] time_seeds = seeds[esm_gv_T].keys() preds_gv = {"time": {k: time_all for k in time_seeds}} @@ -154,9 +159,13 @@ def _convert_raw_mesmer_to_xarray(emulations, land_fractions, time): tmp = [] for scenario, outputs in emulations.items(): for variable, values in outputs.items(): - time = np.concatenate([time["hist"], time[scenario.replace("h-", "")]]) + if "hist" in time: + time_ = np.concatenate([time["hist"], time[scenario.replace("h-", "")]]) + else: + time_ = time[scenario] + variable_out = ( - land_fractions_stacked.expand_dims({"year": time}) + land_fractions_stacked.expand_dims({"year": time_}) .expand_dims({"realisation": range(values.shape[0])}) .copy() ) diff --git a/tests/integration/test_draw_realisations_from_bundle.py b/tests/integration/test_draw_realisations_from_bundle.py index 95f8eb36..1eb352a4 100644 --- a/tests/integration/test_draw_realisations_from_bundle.py +++ b/tests/integration/test_draw_realisations_from_bundle.py @@ -21,7 +21,7 @@ def test_make_realisations(test_data_root_dir, update_expected_files): ouput_dir, "test_make_realisations_expected_output.nc" ) - tseeds = {"IPSL-CM6A-LR": {"all": {"gv": 0, "lv": 1000000}}} + tseeds = {"IPSL-CM6A-LR": {"all": {"gv": 0, "lv": 1_000_000}}} bundle_path = os.path.join(ouput_dir, "test-mesmer-bundle.pkl")