Skip to content

Commit

Permalink
small fixes and refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
mathause committed Sep 18, 2023
1 parent c7e2064 commit 8fe5456
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 20 deletions.
28 changes: 16 additions & 12 deletions mesmer/calibrate_mesmer/train_gt.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,14 @@
from mesmer.utils import separate_hist_future


def train_gt(var, targ, esm, time, cfg, save_params=True):
def train_gt(data, targ, esm, time, cfg, save_params=True):
"""
Derive global trend (emissions + volcanoes) parameters from specified ensemble type
with specified method.
Parameters
----------
var : dict
data : dict
nested global mean variable dictionary with keys for each scenario employed for
training
Expand Down Expand Up @@ -75,7 +75,7 @@ def train_gt(var, targ, esm, time, cfg, save_params=True):
method_gt = cfg.methods[targ]["gt"]
preds_gt = cfg.preds[targ]["gt"]

scenarios = list(var.keys())
scenarios = list(data.keys())

# initialize param dict and fill in the metadata which does not depend on the method
params_gt = {}
Expand All @@ -90,7 +90,7 @@ def train_gt(var, targ, esm, time, cfg, save_params=True):
if "LOWESS" in params_gt["method"]:
# derive gt for each scen individually
for scen in scenarios:
gt[scen], frac_lowess = train_gt_ic_LOWESS(var[scen])
gt[scen], frac_lowess = train_gt_ic_LOWESS(data[scen])
params_gt["frac_lowess"] = frac_lowess
else:
raise ValueError("No alternative method to LOWESS is implemented for now.")
Expand All @@ -101,17 +101,21 @@ def train_gt(var, targ, esm, time, cfg, save_params=True):
gt_s, time_s = separate_hist_future(gt, time, cfg)

# compute median LOWESS estimate of historical part across all scenarios
gt_lowess_hist_all_new = gt_s.pop("hist")
gt_hist_all = gt_s.pop("hist")

gt_hist_median = np.median(gt_lowess_hist_all_new, axis=0)
gt_hist_median = np.median(gt_hist_all, axis=0)

if params_gt["method"] == "LOWESS_OLSVOLC":
var_s, time_s = separate_hist_future(var, time, cfg)
data_s, time_s = separate_hist_future(data, time, cfg)

params_gt["saod"], params_gt["hist"] = train_gt_ic_OLSVOLC(
var_s["hist"], gt_hist_median, time_s["hist"]
# estimate volcanic influence and add to smooth time series
coef_saod, gt_hist_olsvolc = train_gt_ic_OLSVOLC(
data_s["hist"], gt_hist_median, time_s["hist"]
)

params_gt["saod"] = coef_saod
params_gt["hist"] = gt_hist_olsvolc

elif params_gt["method"] == "LOWESS":
params_gt["hist"] = gt_hist_median

Expand Down Expand Up @@ -226,16 +230,16 @@ def train_gt_ic_OLSVOLC(var, gt_lowess, time, cfg=None):
# drop "year" coords - aod_obs does not have coords (currently)
aod_obs = aod_obs.drop_vars("year")

# repeat aod time series as many times as runs available
aod_obs_all = xr.concat([aod_obs] * nr_runs, dim="year")

nr_aod_obs = aod_obs.shape[0]
if nr_ts != nr_aod_obs:
raise ValueError(
f"The number of time steps of the variable ({nr_ts}) and the saod "
f"({nr_aod_obs}) do not match."
)

# repeat aod time series as many times as runs available
aod_obs_all = xr.concat([aod_obs] * nr_runs, dim="year")

# extract global variability (which still includes volc eruptions) by removing
# smooth trend from Tglob in historic period
# (should broadcast, and flatten the correct way - hopefully)
Expand Down
2 changes: 1 addition & 1 deletion mesmer/create_emulations/create_emus_gt.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def gather_gt_data(params_gt, preds_gt, cfg, concat_h_f=False, save_emus=True):

emus_gt = {}

# apply the chosen method
# gather data
if "LOWESS" in params_gt["method"]:
for scen in scenarios_emus:
emus_gt[scen] = params_gt[scen]
Expand Down
3 changes: 1 addition & 2 deletions mesmer/io/load_constant_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,8 +140,7 @@ def load_regs_ls_wgt_lon_lat(reg_type=None, lon=None, lat=None):
Parameters
----------
reg_type : str, optional, default: None
Deprecated. No longer has an effect, if None is passed this
function will only return four parameters.
Deprecated, no longer has an effect.
lon : dict
longitude dictionary with key
Expand Down
8 changes: 3 additions & 5 deletions mesmer/utils/select.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,7 @@ def extract_land(var, reg_dict=None, wgt=None, ls=None, threshold_land=0.25):
- [esm][scen] (4d array (run, time, lat, lon) of variable)
reg_dict : dict | None
Deprecated. No longer has an effect (except changing the number of output
params).
Deprecated. No longer has an effect.
wgt : np.ndarray
2d array (lat, lon) of weights to be used for area weighted means
Expand All @@ -49,8 +48,7 @@ def extract_land(var, reg_dict=None, wgt=None, ls=None, threshold_land=0.25):
- [esm] (3d array (run, time, gp_l) of variable at land grid points)
reg_dict : dict
Optional output (empty dict). Only returned when the input ``reg_dict`` is not
``None``.
Deprecated (empty dict).
ls : dict
land sea dictionary with added keys
Expand Down Expand Up @@ -131,7 +129,7 @@ def extract_time_period(data, time, start, end):
"""

warnings.warn(
"`extract_time_period` is deprecated in v0.9.0 and will be remove in a future "
"`extract_time_period` is deprecated in v0.9.0 and will be removed in a future "
"version. Please raise an issue if you still use this function.",
FutureWarning,
)
Expand Down

0 comments on commit 8fe5456

Please sign in to comment.