
Commit

Add policy control and auto-detect. NOTE: for now only load, not load_cubes/load_cube
pp-mo committed Oct 11, 2024
1 parent a20b7bd commit 23f35d4
Showing 2 changed files with 166 additions and 5 deletions.
126 changes: 124 additions & 2 deletions lib/iris/__init__.py
@@ -292,8 +292,17 @@ def _generate_cubes(uris, callback, constraints):

def _load_collection(uris, constraints=None, callback=None):
from iris.cube import _CubeFilterCollection
from iris.fileformats.rules import _MULTIREF_DETECTION

try:
# This routine is called once per iris load operation.
# Control of the "multiple refs" handling is implicit in this routine.
# NOTE: detection of multiple reference fields, and its enabling of post-load
# concatenation, is triggered **per-load, not per-cube**.
# This behaves unexpectedly for "iris.load_cubes": a post-concatenation is
# triggered for all cubes or none, not per-cube (i.e. per constraint).
_MULTIREF_DETECTION.found_multiple_refs = False

cubes = _generate_cubes(uris, callback, constraints)
result = _CubeFilterCollection.from_cubes(cubes, constraints)
except EOFError as e:
@@ -303,7 +312,118 @@ def _load_collection(uris, constraints=None, callback=None):
return result


def load(uris, constraints=None, callback=None):
class LoadPolicy(threading.local):
"""Object defining a general loading strategy."""

_allkeys = (
"support_multiple_references",
"multiref_triggers_concatenate",
"use_concatenate",
"use_merge",
"cat_before_merge",
"repeat_until_done",
)

def __init__(
self,
support_multiple_references: bool = False,
multiref_triggers_concatenate: bool = False,
use_concatenate: bool = False,
use_merge: bool = True,
cat_before_merge: bool = False,
repeat_until_done: bool = False,
):
"""Container for loading controls."""
self.support_multiple_references = support_multiple_references
self.multiref_triggers_concatenate = multiref_triggers_concatenate
self.use_concatenate = use_concatenate
self.use_merge = use_merge
self.cat_before_merge = cat_before_merge
self.repeat_until_done = repeat_until_done

def __repr__(self):
msg = (
"LoadPolicy("
f"support_multiple_references={self.support_multiple_references}, "
f"multiref_triggers_concatenate={self.multiref_triggers_concatenate}, "
f"use_concatenate={self.use_concatenate}, "
f"use_merge={self.use_merge}, "
f"cat_before_merge={self.cat_before_merge}, "
f"repeat_until_done={self.repeat_until_done}"
")"
)
return msg

def copy(self):
return LoadPolicy(**{key: getattr(self, key) for key in self._allkeys})

@contextlib.contextmanager
def context(self, policy=None, **kwargs):
"""Return context manager for temporary options.
Modifies the given parameters within a context, for the active thread.
"""
# Save the current state
current_state = self.__dict__.copy()

# Update the state from given policy object and/or method keywords
for name in self._allkeys:
value = getattr(self, name)
if policy and hasattr(policy, name):
value = getattr(policy, name)
if name in kwargs:
value = kwargs[name]
setattr(self, name, value)

try:
# Execute the context
yield
finally:
# Restore the saved state
self.__dict__.clear()
self.__dict__.update(current_state)


LOAD_POLICY = LoadPolicy()
LOAD_POLICY_LEGACY = LoadPolicy()
LOAD_POLICY_RECOMMENDED = LoadPolicy(
support_multiple_references=True, multiref_triggers_concatenate=True
)
LOAD_POLICY_COMPREHENSIVE = LoadPolicy(
support_multiple_references=True, use_concatenate=True, repeat_until_done=True
)
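
As a usage sketch (the file name is hypothetical), a preset can be adopted for the duration of one load via the context manager, or individual controls overridden:

import iris

# Temporarily adopt a whole preset (thread-local; restored on exit).
with iris.LOAD_POLICY.context(iris.LOAD_POLICY_COMPREHENSIVE):
    cubes = iris.load("data.pp")  # hypothetical file

# Or override individual controls for one load.
with iris.LOAD_POLICY.context(use_concatenate=True):
    cubes = iris.load("data.pp")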


def _current_effective_policy():
policy = LOAD_POLICY
if not policy.use_concatenate and policy.multiref_triggers_concatenate:
from iris.fileformats.rules import _MULTIREF_DETECTION

if _MULTIREF_DETECTION.found_multiple_refs:
policy = policy.copy()
policy.use_concatenate = True
return policy
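
In effect, the decision above amounts to the following, shown here as a sketch:

# Concatenation is enabled for this load when requested explicitly, or when
# multiple reference fields were detected and the policy says that detection
# should trigger it.
effective_concat = policy.use_concatenate or (
    policy.multiref_triggers_concatenate
    and _MULTIREF_DETECTION.found_multiple_refs
)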


def _apply_loading_policy(cubes, policy=None):
if not policy:
policy = _current_effective_policy()
while True:
n_original_cubes = len(cubes)
if policy.use_concatenate and policy.cat_before_merge:
cubes = cubes.concatenate()
if policy.use_merge:
cubes = cubes.merge()
if policy.use_concatenate and not policy.cat_before_merge:
cubes = cubes.concatenate()
n_new_cubes = len(cubes)
if not policy.repeat_until_done or n_new_cubes >= n_original_cubes:
break

return cubes
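
The loop above is a fixed-point iteration: with repeat_until_done, the merge/concatenate passes repeat until they stop reducing the cube count. The same pattern in isolation, as a sketch:

def combine_until_stable(cubes, step):
    # Apply a combining step (e.g. merge or concatenate) until it no longer
    # reduces the number of cubes.
    while True:
        n_before = len(cubes)
        cubes = step(cubes)
        if len(cubes) >= n_before:
            return cubes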


def load(uris, constraints=None, callback=None, policy=None):
"""Load any number of Cubes for each constraint.
For a full description of the arguments, please see the module
@@ -327,7 +447,9 @@ def load(uris, constraints=None, callback=None):
were random.
"""
return _load_collection(uris, constraints, callback).merged().cubes()
cubes = _load_collection(uris, constraints, callback).cubes()
cubes = _apply_loading_policy(cubes, policy)
return cubes
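
Since load now accepts a policy argument, a one-off policy can also be supplied per call, as a sketch (file name hypothetical):

cubes = iris.load("data.pp", policy=iris.LOAD_POLICY_COMPREHENSIVE)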


def load_cube(uris, constraint=None, callback=None):
45 changes: 42 additions & 3 deletions lib/iris/fileformats/rules.py
@@ -5,6 +5,7 @@
"""Generalised mechanisms for metadata translation and cube construction."""

import collections
import threading
import warnings

import cf_units
@@ -143,7 +144,11 @@ class _ReferenceError(Exception):


def _dereference_args(factory, reference_targets, regrid_cache, cube):
"""Convert all the arguments for a factory into concrete coordinates."""
"""Convert all the arguments for a factory into concrete coordinates.
Note: where multiple reference fields define an additional dimension, this routine
returns a modified 'cube', with the necessary additional dimensions.
"""
args = []
for arg in factory.args:
if isinstance(arg, Reference):
@@ -178,6 +183,7 @@ def _dereference_args(factory, reference_targets, regrid_cache, cube):
# If it wasn't a Reference, then arg is a dictionary
# of keyword arguments for cube.coord(...).
args.append(cube.coord(**arg))

return args, cube


@@ -224,18 +230,24 @@ def _ensure_aligned(regrid_cache, src_cube, target_cube):
# single, distinct dimension.
# PP-MOD: first promote any scalar coords to dims, where needed
for target_coord in target_dimcoords:
if not target_cube.coord_dims(target_coord):
from iris import LOAD_POLICY

if (
not target_cube.coord_dims(target_coord)
and LOAD_POLICY.support_multiple_references
):
# The chosen coord is not a dimcoord in the target (yet)
# Make it one with 'new_axis'
from iris.util import new_axis

_MULTIREF_DETECTION.found_multiple_refs = True
# Include the other coords on that dim in the src: this means the
# src merge identifies which coords belong on that dim
# (e.g. 'forecast_period' along with 'time').
(src_dim,) = src_cube.coord_dims(target_coord) # should have 1 dim
promote_other_coords = [
target_cube.coord(src_coord)
for src_coord in src_cube.coords(dimensions=src_dim)
for src_coord in src_cube.coords(contains_dimension=src_dim)
if src_coord.name() != target_coord.name()
]
target_cube = new_axis(
@@ -364,9 +376,35 @@ def _resolve_factory_references(
aux_factory = factory.factory_class(*args)
cube.add_aux_factory(aux_factory)

# In the case of multiple references which vary on a new dimension
# (such as time-dependent orography or surface pressure), the cube may be replaced
# by one with an additional dimension.
# In that case we must update each factory so that its dependencies are coords of
# the new cube.
cube_coord_ids = [
id(coord) for coord, _ in cube._dim_coords_and_dims + cube._aux_coords_and_dims
]
for factory in cube.aux_factories:
for name, dep in list(factory.dependencies.items()):
if id(dep) not in cube_coord_ids:
factory.update(dep, cube.coord(dep))

return cube
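
The identity test above (rather than ==) is deliberate: after the cube is rebuilt with a new dimension, a factory dependency can compare equal to a coord on the new cube while being a different object. A small illustration:

from iris.coords import AuxCoord

a = AuxCoord([1.0], long_name="surface_altitude")
b = AuxCoord([1.0], long_name="surface_altitude")
assert a == b       # equal metadata and values ...
assert a is not b   # ... but distinct objects, so id(a) != id(b)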


class MultipleReferenceFieldDetector(threading.local):
def __init__(self):
self.found_multiple_refs = False


# A single global object (per thread) to record whether multiple reference fields
# (e.g. time-dependent orography, or surface pressure fields) have been detected during
# the latest load operation.
# This is used purely to implement the iris.LOAD_POLICY.multiref_triggers_concatenate
# functionality.
_MULTIREF_DETECTION = MultipleReferenceFieldDetector()
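
Because the detector subclasses threading.local, each thread sees an independent flag; a small sketch:

import threading

det = MultipleReferenceFieldDetector()
det.found_multiple_refs = True  # set in the main thread

def worker():
    # A new thread re-runs __init__ on first access, so it sees the default
    # False, unaffected by the main thread's setting.
    print(det.found_multiple_refs)  # -> False

t = threading.Thread(target=worker)
t.start()
t.join()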


def _load_pairs_from_fields_and_filenames(
fields_and_filenames, converter, user_callback_wrapper=None
):
@@ -376,6 +414,7 @@ def _load_pairs_from_fields_and_filenames(
# needs a filename associated with each field to support the load callback.
concrete_reference_targets = {}
results_needing_reference = []

for field, filename in fields_and_filenames:
# Convert the field to a Cube, passing down the 'converter' function.
cube, factories, references = _make_cube(field, converter)
