Skip to content

Commit

Permalink
example use make_realisations function (#325)
Browse files Browse the repository at this point in the history
  • Loading branch information
mathause authored Nov 6, 2023
1 parent 3c2be6a commit 74b19c9
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 37 deletions.
56 changes: 23 additions & 33 deletions examples/train_create_emus_automated.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,15 @@
import warnings

import xarray as xr

# import MESMER tools
from mesmer.calibrate_mesmer import train_gt, train_gv, train_lt, train_lv
from mesmer.create_emulations import (
create_emus_g,
create_emus_gv,
create_emus_l,
create_emus_lt,
create_emus_lv,
gather_gt_data,
make_realisations,
)
from mesmer.create_emulations.utils import concatenate_hist_future
from mesmer.io import load_cmipng, load_phi_gc, load_regs_ls_wgt_lon_lat
from mesmer.utils import convert_dict_to_arr, extract_land, separate_hist_future

Expand Down Expand Up @@ -71,7 +70,7 @@ def main(cfg):
print(f"{esm}")
print("=" * len(esm))

print("Calibration")
print("\nCalibration")
print("-----------")

print("- Start with global trend module")
Expand Down Expand Up @@ -139,42 +138,33 @@ def main(cfg):
{}, targs_res_lv, esm, cfg, save_params=True, aux=aux, params_lv=params_lv
)

print("Emulation")
print("\nEmulation")
print("---------")

# for this example we use the model's own smoothed gsat as predictor
gt_tas = concatenate_hist_future(gt_tas_s)

scen = list(gt_tas.keys())[0]

time_v = {}
time_v["all"] = time[esm][scen]

print("- Create global variability emulations")

preds_gv = {"time": time_v}
emus_gv_tas = create_emus_gv(params_gv_tas, preds_gv, cfg, save_emus=True)

print("- Merge the global trend and the global variability.")

create_emus_g(
gt_tas, emus_gv_tas, params_gt_tas, params_gv_tas, cfg, save_emus=True
# create a land_fraction DataArray, so we can determine the grid coordinates
land_fractions = xr.DataArray(
ls["grid_l_m"],
dims=["lat", "lon"],
coords={"lat": lat["c"], "lon": lon["c"]},
)

print("- Create local trend emulations")

emus_lt = create_emus_lt(
params_lt, preds_lt, cfg, concat_h_f=True, save_emus=True
realisations = make_realisations(
# preds_lt=gt_tas,
preds_lt=preds_lt,
params_lt=params_lt,
params_lv=params_lv,
params_gv_T=params_gv_tas,
time=time[esm],
n_realisations=cfg.nr_emus_v,
seeds=cfg.seed,
land_fractions=land_fractions,
)

print("- Create local variability emulations")

preds_lv = {"gvtas": emus_gv_tas} # predictors_list
emus_lv = create_emus_lv(params_lv, preds_lv, cfg, save_emus=True)
print("\nCreated emulations")
print("------------------")

# create and save full emulations
print("- Merge the local trends and the local variability.")
create_emus_l(emus_lt, emus_lv, params_lt, params_lv, cfg, save_emus=True)
print(realisations)


if __name__ == "__main__":
Expand Down
15 changes: 12 additions & 3 deletions mesmer/create_emulations/make_realisations.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,12 @@ def __init__(self, n_realisations, seeds):
# TODO: add better checks for what happens if scenarios have different
# time axis etc.
a_scenario_key = [k for k in time.keys() if k != "hist"][0]
time_all = np.concatenate([time["hist"], time[a_scenario_key]])

if "hist" in time:
time_all = np.concatenate([time["hist"], time[a_scenario_key]])
else:
assert a_scenario_key.startswith("h-")
time_all = time[a_scenario_key]
esm_gv_T = params_gv_T["esm"]
time_seeds = seeds[esm_gv_T].keys()
preds_gv = {"time": {k: time_all for k in time_seeds}}
Expand All @@ -154,9 +159,13 @@ def _convert_raw_mesmer_to_xarray(emulations, land_fractions, time):
tmp = []
for scenario, outputs in emulations.items():
for variable, values in outputs.items():
time = np.concatenate([time["hist"], time[scenario.replace("h-", "")]])
if "hist" in time:
time_ = np.concatenate([time["hist"], time[scenario.replace("h-", "")]])
else:
time_ = time[scenario]

variable_out = (
land_fractions_stacked.expand_dims({"year": time})
land_fractions_stacked.expand_dims({"year": time_})
.expand_dims({"realisation": range(values.shape[0])})
.copy()
)
Expand Down
2 changes: 1 addition & 1 deletion tests/integration/test_draw_realisations_from_bundle.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def test_make_realisations(test_data_root_dir, update_expected_files):
ouput_dir, "test_make_realisations_expected_output.nc"
)

tseeds = {"IPSL-CM6A-LR": {"all": {"gv": 0, "lv": 1000000}}}
tseeds = {"IPSL-CM6A-LR": {"all": {"gv": 0, "lv": 1_000_000}}}

bundle_path = os.path.join(ouput_dir, "test-mesmer-bundle.pkl")

Expand Down

0 comments on commit 74b19c9

Please sign in to comment.