Skip to content

Commit

Permalink
_calibrate_tas: allow selecting predictors
Browse files Browse the repository at this point in the history
  • Loading branch information
mathause committed Sep 18, 2023
1 parent 7c5fdba commit a2b4ed1
Show file tree
Hide file tree
Showing 11 changed files with 71 additions and 25 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ generated

# allow saving mesmer params without commiting them to the repo
tests/test-data/output/*/params
tests/test-data/output/*/*/params
devel/*

# output folder of examples/train_create_emus_automated.py
Expand Down
70 changes: 51 additions & 19 deletions mesmer/calibrate_mesmer/calibrate_mesmer.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def __init__(
cross_validation_max_iterations,
save_params=False,
params_output_dir=None,
use_hfds=True,
**kwargs,
):

Expand Down Expand Up @@ -71,8 +72,11 @@ def __init__(
"lt": tas_local_trend_method,
"lv": tas_local_variability_method,
},
"hfds": {"gt": hfds_global_trend_method},
}

if use_hfds:
self.methods["hfds"] = {"gt": hfds_global_trend_method}

self.method_lt_each_gp_sep = method_lt_each_gp_sep

# Essentially metadata about predictors used. Maybe used for naming etc.
Expand All @@ -83,8 +87,11 @@ def __init__(
"gt": ["saod"],
"gv": [],
},
"hfds": {"gt": []},
}

if use_hfds:
self.preds["hfds"] = {"gt": []}

self.preds["tas"]["g_all"] = self.preds["tas"]["gt"] + self.preds["tas"]["gv"]

self.wgt_scen_tr_eq = weight_scenarios_equally
Expand Down Expand Up @@ -133,6 +140,8 @@ def _calibrate_tas(
cross_validation_max_iterations=30,
save_params=False,
params_output_dir=None,
use_tas2=True,
use_hfds=True,
**kwargs,
):
"""
Expand Down Expand Up @@ -171,6 +180,7 @@ def _calibrate_tas(
cross_validation_max_iterations=cross_validation_max_iterations,
save_params=save_params,
params_output_dir=params_output_dir,
use_hfds=use_hfds,
)

for esm in esms:
Expand All @@ -190,11 +200,14 @@ def _calibrate_tas(
# unpack data
tas_temp[scen], gsat_temp[scen], lon, lat, time[esm][scen] = out

_, ghfds_temp[scen], _, _, _ = load_cmipng("hfds", esm, scen, cfg)
if use_hfds:
_, ghfds_temp[scen], _, _, _ = load_cmipng("hfds", esm, scen, cfg)

tas_g[esm] = convert_dict_to_arr(tas_temp)
gsat[esm] = convert_dict_to_arr(gsat_temp)
ghfds[esm] = convert_dict_to_arr(ghfds_temp)

if use_hfds:
ghfds[esm] = convert_dict_to_arr(ghfds_temp)

# load in the constant files
_, ls, wgt_g, lon, lat = load_regs_ls_wgt_lon_lat(lon=lon, lat=lat)
Expand All @@ -210,10 +223,11 @@ def _calibrate_tas(
params_gt_tas = train_gt(
gsat[esm], "tas", esm, time[esm], cfg, save_params=cfg.save_params
)
# TODO: remove hard-coded hfds
params_gt_hfds = train_gt(
ghfds[esm], "hfds", esm, time[esm], cfg, save_params=cfg.save_params
)

if use_hfds:
params_gt_hfds = train_gt(
ghfds[esm], "hfds", esm, time[esm], cfg, save_params=cfg.save_params
)

# From params_gt_T, extract the global-trend so that the global variability,
# local trends, and local variability modules can be trained.
Expand All @@ -229,11 +243,16 @@ def _calibrate_tas(
LOGGER.info(
"Prepare predictors for global variability, local trends variability"
)
gt_tas2_s = {scen: gt_tas_scen**2 for scen, gt_tas_scen in gt_tas_s.items()}

gt_hfds_s = gather_gt_data(
params_gt_hfds, preds_gt, cfg, concat_h_f=False, save_emus=False
)
if use_tas2:
gt_tas2_s = {
scen: gt_tas_scen**2 for scen, gt_tas_scen in gt_tas_s.items()
}

if use_hfds:
gt_hfds_s = gather_gt_data(
params_gt_hfds, preds_gt, cfg, concat_h_f=False, save_emus=False
)

# calculate tas residuals
gv_novolc_tas = {scen: gsat[esm][scen] - gt_tas[scen] for scen in gt_tas}
Expand All @@ -248,12 +267,18 @@ def _calibrate_tas(
)

LOGGER.info("Calibrating local trends module")
preds = {
"gttas": gt_tas_s,
"gttas2": gt_tas2_s,
"gthfds": gt_hfds_s,
"gvtas": gv_novolc_tas_s,
}
preds = {}

preds["gttas"] = gt_tas_s

if use_tas2:
preds["gttas2"] = gt_tas2_s

if use_hfds:
preds["gthfds"] = gt_hfds_s

preds["gvtas"] = gv_novolc_tas_s

targs = {"tas": tas_s}
params_lt, params_lv = train_lt(
preds, targs, esm, cfg, save_params=cfg.save_params
Expand All @@ -262,7 +287,14 @@ def _calibrate_tas(
# Create forced local warming samples used for training the local variability
# module. Samples are cheap to create so not an issue to have here.
LOGGER.info("Creating local trends emulations")
preds_lt = {"gttas": gt_tas_s, "gttas2": gt_tas2_s, "gthfds": gt_hfds_s}
preds_lt = {"gttas": gt_tas_s}

if use_tas2:
preds_lt["gttas2"] = gt_tas2_s

if use_hfds:
preds_lt["gthfds"] = gt_hfds_s

lt_s = create_emus_lt(
params_lt, preds_lt, cfg, concat_h_f=False, save_emus=False
)
Expand Down
21 changes: 16 additions & 5 deletions tests/integration/test_calibrate_mesmer.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,24 @@

@pytest.mark.filterwarnings("ignore:No local minimum found")
@pytest.mark.parametrize(
"scenarios, outname",
"scenarios, use_tas2, use_hfds, outname",
(
[["h-ssp126"], "one_scen_one_ens"],
[["h-ssp585"], "one_scen_multi_ens"],
[["h-ssp126", "h-ssp585"], "multi_scen_multi_ens"],
[["h-ssp126"], True, True, "tas_tas2_hfds/one_scen_one_ens"],
[["h-ssp585"], True, True, "tas_tas2_hfds/one_scen_multi_ens"],
[["h-ssp126", "h-ssp585"], True, True, "tas_tas2_hfds/multi_scen_multi_ens"],
[["h-ssp126"], True, False, "tas_tas2/one_scen_one_ens"],
[["h-ssp126"], False, True, "tas_hfds/one_scen_one_ens"],
[["h-ssp126"], False, False, "tas/one_scen_one_ens"],
),
)
def test_calibrate_mesmer(
scenarios, outname, test_data_root_dir, tmpdir, update_expected_files
scenarios,
use_tas2,
use_hfds,
outname,
test_data_root_dir,
tmpdir,
update_expected_files,
):

ouput_dir = os.path.join(test_data_root_dir, "output", outname)
Expand Down Expand Up @@ -56,6 +65,8 @@ def test_calibrate_mesmer(
cmip_generation=test_cmip_generation,
observations_root_dir=test_observations_root_dir,
auxiliary_data_dir=test_auxiliary_data_dir,
use_tas2=use_tas2,
use_hfds=use_hfds,
# save params as well - they are .gitignored
save_params=update_expected_files,
params_output_dir=params_output_dir,
Expand Down
4 changes: 3 additions & 1 deletion tests/integration/test_draw_realisations_from_bundle.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@

def test_make_realisations(test_data_root_dir, update_expected_files):

ouput_dir = os.path.join(test_data_root_dir, "output", "one_scen_one_ens")
ouput_dir = os.path.join(
test_data_root_dir, "output", "tas_tas2_hfds", "one_scen_one_ens"
)

expected_output_file = os.path.join(
ouput_dir, "test_make_realisations_expected_output.nc"
Expand Down
Binary file not shown.
Binary file not shown.
Binary file not shown.

0 comments on commit a2b4ed1

Please sign in to comment.