Skip to content

Commit

Permalink
Cleanup nudging stepper
Browse files Browse the repository at this point in the history
Refactor nudging logic to a more "pure" class.
  • Loading branch information
nbren12 committed Feb 23, 2021
1 parent 419a452 commit 9ed918f
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 68 deletions.
2 changes: 1 addition & 1 deletion workflows/prognostic_c48_run/runtime/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
nudging_timescales_from_dict,
setup_get_reference_state,
get_nudging_tendency,
set_state_sst_to_reference,
get_reference_surface_temperatures,
)

__version__ = "0.1.0"
1 change: 0 additions & 1 deletion workflows/prognostic_c48_run/runtime/loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,6 @@ def _get_stepper(self, config: UserConfig) -> Stepper:
self._comm.rank,
config.nudging,
timestep=self._timestep,
states_to_output=self._states_to_output,
communicator=communicator,
)
else:
Expand Down
24 changes: 15 additions & 9 deletions workflows/prognostic_c48_run/runtime/nudging.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,19 +236,25 @@ def get_nudging_tendency(
return return_dict


def set_state_sst_to_reference(state: State, reference: State) -> State:
def get_reference_surface_temperatures(state: State, reference: State) -> State:
"""
Set the sea surface and surface temperatures in a model state to values in
a reference state. Useful for maintaining consistency between a nudged run
and reference state.
"""
state[SST_NAME] = _sst_from_reference(
reference[TSFC_NAME], state[SST_NAME], state[MASK_NAME]
)
state[TSFC_NAME] = _sst_from_reference(
reference[TSFC_NAME], state[TSFC_NAME], state[MASK_NAME]
)
return state
state = {
SST_NAME: _sst_from_reference(
reference[TSFC_NAME], state[SST_NAME], state[MASK_NAME]
),
TSFC_NAME: _sst_from_reference(
reference[TSFC_NAME], state[TSFC_NAME], state[MASK_NAME]
),
}
# TODO fix this bug in a follow-up non-refactor PR
# this logic replicates a bug in the previous nudged run
# the SSTs were never actually applied to the fv3gfs-wrapper state
# This bug could be scientifically significant.
return {}


def _sst_from_reference(
Expand All @@ -260,4 +266,4 @@ def _sst_from_reference(
land_sea_mask.values.round().astype("int") == 0,
reference_surface_temperature,
surface_temperature,
)
).assign_attrs(units=reference_surface_temperature.units)
105 changes: 48 additions & 57 deletions workflows/prognostic_c48_run/runtime/steppers/nudging.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
import functools
from typing import (
Any,
List,
Sequence,
)
from typing import Any

import fv3gfs.util
import fv3gfs.wrapper

from runtime.steppers.base import (
Stepper,
Expand All @@ -22,91 +19,85 @@
nudging_timescales_from_dict,
setup_get_reference_state,
get_nudging_tendency,
set_state_sst_to_reference,
get_reference_surface_temperatures,
NudgingConfig,
)

from runtime.names import (
DELP,
TOTAL_PRECIP,
PRECIP_RATE,
)
from runtime.names import TOTAL_PRECIP


SST_NAME = "ocean_surface_temperature"
TSFC_NAME = "surface_temperature"
MASK_NAME = "land_sea_mask"


class NudgingStepper(Stepper, LoggingMixin):
"""Stepper for nudging
"""
class PureNudger:

name = "nudging"

def __init__(
self,
state,
fv3gfs: Any,
rank: int,
config: NudgingConfig,
timestep: float,
states_to_output: Sequence[str],
communicator: fv3gfs.util.CubedSphereCommunicator,
self, config: NudgingConfig, communicator: fv3gfs.util.CubedSphereCommunicator,
):

self._states_to_output = states_to_output
self._state = state

self._fv3gfs = fv3gfs
self.rank: int = rank
self._timestep: float = timestep

self._nudging_timescales = nudging_timescales_from_dict(config.timescale_hours)
variables_to_nudge = list(config.timescale_hours)
self._get_reference_state = setup_get_reference_state(
config,
self.nudging_variables + [SST_NAME, TSFC_NAME],
self._fv3gfs.get_tracer_metadata(),
variables_to_nudge + [SST_NAME, TSFC_NAME],
fv3gfs.wrapper.get_tracer_metadata(),
communicator,
)

self._nudging_timescales = nudging_timescales_from_dict(config.timescale_hours)
self._get_nudging_tendency = functools.partial(
get_nudging_tendency, nudging_timescales=self._nudging_timescales,
)
self._tendencies_to_apply_to_dycore_state: State = {}

@property
def nudging_variables(self) -> List[str]:
return list(self._nudging_timescales)

def _compute_python_tendency(self) -> Diagnostics:

self._log_debug("Computing nudging tendencies")
variables: List[str] = self.nudging_variables + [
SST_NAME,
TSFC_NAME,
MASK_NAME,
]
state: State = {name: self._state[name] for name in variables}
reference = self._get_reference_state(self._state.time)
set_state_sst_to_reference(state, reference)
self._tendencies_to_apply_to_dycore_state = self._get_nudging_tendency(
state, reference
)
def __call__(self, time, state):
reference = self._get_reference_state(time)
tendencies = get_nudging_tendency(state, reference, self._nudging_timescales)
ssts = get_reference_surface_temperatures(state, reference)

return {
reference = {
f"{key}_reference": reference_state
for key, reference_state in reference.items()
}
return tendencies, ssts, reference

def _apply_python_to_dycore_state(self) -> Diagnostics:

tendency = self._tendencies_to_apply_to_dycore_state
class NudgingStepper(Stepper, LoggingMixin):
"""Stepper for nudging"""

def __init__(
self,
state,
fv3gfs: Any,
rank: int,
config: NudgingConfig,
timestep: float,
communicator: fv3gfs.util.CubedSphereCommunicator,
):
self._state = state
self._timestep: float = timestep
self.nudger = PureNudger(config, communicator)
self._tendencies: State = {}
self._state_updates: State = {}

def _compute_python_tendency(self) -> Diagnostics:
(self._tendencies, self._state_updates, diagnostics,) = self.nudger(
self._state.time, self._state
)
return diagnostics

def _apply_python_to_dycore_state(self) -> Diagnostics:

diagnostics = compute_nudging_diagnostics(self._state, tendency)
updated_state: State = apply(self._state, tendency, dt=self._timestep)
diagnostics = compute_nudging_diagnostics(self._state, self._tendencies)
updated_state: State = apply(self._state, self._tendencies, dt=self._timestep)
updated_state[TOTAL_PRECIP] = precipitation_sum(
self._state[TOTAL_PRECIP],
diagnostics["net_moistening_due_to_nudging"],
diagnostics[f"net_moistening_due_to_{self.nudger.name}"],
self._timestep,
)
diagnostics[TOTAL_PRECIP] = updated_state[TOTAL_PRECIP]
self._state.update(updated_state)
self._state.update(self._state_updates)
return diagnostics

0 comments on commit 9ed918f

Please sign in to comment.