Skip to content

Commit

Permalink
Merge pull request #114 from Deltares/DEI-185-depth-average-rule
Browse files Browse the repository at this point in the history
Dei 185 depth average rule
  • Loading branch information
mKlapwijk authored Jul 8, 2024
2 parents de6831f + a423fe3 commit 7334cdb
Show file tree
Hide file tree
Showing 83 changed files with 982 additions and 209 deletions.
59 changes: 35 additions & 24 deletions decoimpact/business/entities/rule_based_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
"""

from typing import List, Optional
from typing import Dict, List, Optional

import xarray as _xr

Expand Down Expand Up @@ -104,9 +104,8 @@ def initialize(self, logger: ILogger) -> None:
self._input_datasets, self._make_output_variables_list(), self._mappings
)
self._rule_processor = RuleProcessor(self._rules, self._output_dataset)
success = self._rule_processor.initialize(logger)

if not success:
if not self._rule_processor.initialize(logger):
logger.log_error("Initialization failed.")

def execute(self, logger: ILogger) -> None:
Expand Down Expand Up @@ -142,8 +141,11 @@ def _make_output_variables_list(self) -> list:
var_list = _du.get_dummy_and_dependent_var_list(dataset)

mapping_keys = list((self._mappings or {}).keys())
rule_names = [rule.name for rule in self._rules]
all_inputs = self._get_direct_rule_inputs(rule_names)
all_input_variables = _lu.flatten_list(list(all_inputs.values()))

all_vars = var_list + mapping_keys + self._get_direct_rule_inputs()
all_vars = var_list + mapping_keys + all_input_variables
return _lu.remove_duplicates_from_list(all_vars)

def _validate_mappings(self, mappings: dict[str, str], logger: ILogger) -> bool:
Expand All @@ -157,7 +159,10 @@ def _validate_mappings(self, mappings: dict[str, str], logger: ILogger) -> bool:
bool: if mappings are valid
"""
input_vars = _lu.flatten_list(
[_du.list_vars(ds) for ds in self._input_datasets]
[
_lu.flatten_list([_du.list_vars(ds), _du.list_coords(ds)])
for ds in self._input_datasets
]
)

valid = True
Expand All @@ -174,8 +179,7 @@ def _validate_mappings(self, mappings: dict[str, str], logger: ILogger) -> bool:
valid = False

# check for duplicates that will be created because of mapping
mapping_vars_created = list(mappings.values())
duplicates_created = _lu.items_in(mapping_vars_created, input_vars)
duplicates_created = _lu.items_in(list(mappings.values()), input_vars)

if len(duplicates_created) > 0:
logger.log_error(
Expand All @@ -185,31 +189,38 @@ def _validate_mappings(self, mappings: dict[str, str], logger: ILogger) -> bool:
)
valid = False

rule_names = [rule.name for rule in self._rules]

rule_inputs = self._get_direct_rule_inputs(rule_names)

# check for missing rule inputs
needed_rule_inputs = _lu.remove_duplicates_from_list(
self._get_direct_rule_inputs()
)
rule_input_vars = input_vars + mapping_vars_created
missing_rule_inputs = _lu.items_not_in(needed_rule_inputs, rule_input_vars)
if len(missing_rule_inputs) > 0:
logger.log_error(
f"Missing the variables '{', '.join(missing_rule_inputs)}' that "
"are required by some rules."
)
valid = False
for rule_name, rule_input in rule_inputs.items():
needed_rule_inputs = _lu.remove_duplicates_from_list(rule_input)
rule_input_vars = input_vars + list(mappings.values())
missing_rule_inputs = _lu.items_not_in(needed_rule_inputs, rule_input_vars)
if len(missing_rule_inputs) > 0:
logger.log_error(
f"Missing the variables '{', '.join(missing_rule_inputs)}' that "
f"are required by '{rule_name}'."
)
valid = False

return valid

def _get_direct_rule_inputs(self) -> List[str]:
def _get_direct_rule_inputs(self, rule_names) -> Dict[str, List[str]]:
"""Gets the input variables directly needed by rules from
input datasets.
Returns:
List[str]:
Dict[str, List[str]]
"""
rule_input_vars = _lu.flatten_list(
[rule.input_variable_names for rule in self._rules]
)
rule_input_vars = [rule.input_variable_names for rule in self._rules]
rule_output_vars = [rule.output_variable_name for rule in self._rules]

return _lu.items_not_in(rule_input_vars, rule_output_vars)
needed_input_per_rule = {}
for index, inputs_per_rule in enumerate(rule_input_vars):
needed_input_per_rule[rule_names[index]] = _lu.items_not_in(
inputs_per_rule, rule_output_vars
)

return needed_input_per_rule
7 changes: 5 additions & 2 deletions decoimpact/business/entities/rule_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

import numpy as _np
import xarray as _xr
import decoimpact.business.utils.dataset_utils as _du
import decoimpact.business.utils.list_utils as _lu

from decoimpact.business.entities.rules.i_array_based_rule import IArrayBasedRule
from decoimpact.business.entities.rules.i_cell_based_rule import ICellBasedRule
Expand Down Expand Up @@ -63,8 +65,9 @@ def initialize(self, logger: ILogger) -> bool:
"""
inputs: List[str] = []

inputs = [str(key) for key in self._input_dataset]

inputs = _lu.flatten_list(
[_du.list_vars(self._input_dataset), _du.list_coords(self._input_dataset)]
)
tree, success = self._create_rule_sets(inputs, self._rules, [], logger)
if success:
self._processing_list = tree
Expand Down
2 changes: 1 addition & 1 deletion decoimpact/business/entities/rules/axis_filter_rule.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# This file is part of D-EcoImpact
# Copyright (C) 2022-2023 Stichting Deltares and D-EcoImpact contributors
# Copyright (C) 2022-2024 Stichting Deltares and D-EcoImpact contributors
# This program is free software distributed under the GNU
# Lesser General Public License version 2.1
# A copy of the GNU General Public License can be found at
Expand Down
108 changes: 108 additions & 0 deletions decoimpact/business/entities/rules/depth_average_rule.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
# This file is part of D-EcoImpact
# Copyright (C) 2022-2024 Stichting Deltares
# This program is free software distributed under the
# GNU Affero General Public License version 3.0
# A copy of the GNU Affero General Public License can be found at
# https://github.com/Deltares/D-EcoImpact/blob/main/LICENSE.md
"""
Module for DepthAverageRule class
Classes:
DepthAverageRule
"""
from typing import Dict

import xarray as _xr

from decoimpact.business.entities.rules.i_multi_array_based_rule import (
IMultiArrayBasedRule,
)
from decoimpact.business.entities.rules.rule_base import RuleBase
from decoimpact.crosscutting.i_logger import ILogger
from decoimpact.crosscutting.delft3d_specific_data import (
INTERFACES_NAME,
BED_LEVEL_NAME,
WATER_LEVEL_NAME,
)


class DepthAverageRule(RuleBase, IMultiArrayBasedRule):
"""Implementation for the depth average rule"""

def execute(
self, value_arrays: Dict[str, _xr.DataArray], logger: ILogger
) -> _xr.DataArray:
"""Calculate depth average of assumed z-layers.
Args:
value_array (DataArray): Values to multiply
Returns:
DataArray: Averaged values
"""

# The first DataArray in our value_arrays contains the values to be averaged
# but the name of the key is given by the user, and is unknown here, so
# just used the first value.
variables = next(iter(value_arrays.values()))

# depths interfaces = borders of the layers in terms of depth
depths_interfaces = value_arrays[INTERFACES_NAME]
water_level_values = value_arrays[WATER_LEVEL_NAME]
bed_level_values = value_arrays[BED_LEVEL_NAME]

# Get the dimension names for the interfaces and for the layers
dim_interfaces_name = list(depths_interfaces.dims)[0]
interfaces_len = depths_interfaces[dim_interfaces_name].size

dim_layer_name = [
d for d in variables.dims if d not in water_level_values.dims
][0]
layer_len = variables[dim_layer_name].size

# interface dimension should always be one larger than layer dimension
# Otherwise give an error to the user
if interfaces_len != layer_len + 1:
logger.log_error(
f"The number of interfaces should be number of layers + 1. Number of"
f"interfaces = {interfaces_len}. Number of layers = {layer_len}."
)
return variables

# Deal with open layer system at water level and bed level
depths_interfaces.values[depths_interfaces.values.argmin()] = -100000
depths_interfaces.values[depths_interfaces.values.argmax()] = 100000

# Broadcast the depths to the dimensions of the bed levels. Then make a
# correction for the depths to the bed level, in other words all depths lower
# than the bed level will be corrected to the bed level.
depths_interfaces_broadcasted = depths_interfaces.broadcast_like(
bed_level_values
)

corrected_depth_bed = depths_interfaces_broadcasted.where(
bed_level_values < depths_interfaces_broadcasted, bed_level_values
)

# Make a similar correction for the waterlevels (first broadcast to match
# dimensions and then replace all values higher than waterlevel with
# waterlevel)
corrected_depth_bed = corrected_depth_bed.broadcast_like(water_level_values)
corrected_depth_bed = corrected_depth_bed.where(
water_level_values > corrected_depth_bed, water_level_values
)

# Calculate the layer heights between depths
layer_heights = corrected_depth_bed.diff(dim=dim_interfaces_name)
layer_heights = layer_heights.rename({dim_interfaces_name: dim_layer_name})

# Use the NaN filtering of the variables to set the correct depth per column
layer_heights = layer_heights.where(variables.notnull())

# Calculate depth average using relative value
relative_values = variables * layer_heights

# Calculate average
return relative_values.sum(dim=dim_layer_name) / layer_heights.sum(
dim=dim_layer_name
)
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# This file is part of D-EcoImpact
# Copyright (C) 2022-2023 Stichting Deltares and D-EcoImpact contributors
# Copyright (C) 2022-2024 Stichting Deltares and D-EcoImpact contributors
# This program is free software distributed under the GNU
# Lesser General Public License version 2.1
# A copy of the GNU General Public License can be found at
Expand Down
14 changes: 13 additions & 1 deletion decoimpact/business/utils/dataset_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,18 @@ def list_vars(dataset: _xr.Dataset) -> list[str]:
return list((dataset.data_vars or {}).keys())


def list_coords(dataset: _xr.Dataset) -> list[str]:
"""List coordinates in dataset
Args:
dataset (_xr.Dataset): Dataset to list variables from
Returns:
list_variables
"""
return list((dataset.coords or {}).keys())


def copy_dataset(dataset: _xr.Dataset) -> _xr.Dataset:
"""Copy dataset to new dataset
Expand Down Expand Up @@ -211,7 +223,7 @@ def get_dummy_variable_in_ugrid(dataset: _xr.Dataset) -> list:

def get_dummy_and_dependent_var_list(dataset: _xr.Dataset) -> list:
"""Obtain the list of variables in a dataset.
The dummy variable is obtained, from which a the variables are
The dummy variable is obtained, from which the variables are
recursively looked up. The dummy and dependent variables are combined
in one list.
This is done to support XUgrid and to prevent invalid topologies.
Expand Down
12 changes: 11 additions & 1 deletion decoimpact/business/workflow/model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from decoimpact.business.entities.rules.axis_filter_rule import AxisFilterRule
from decoimpact.business.entities.rules.classification_rule import ClassificationRule
from decoimpact.business.entities.rules.combine_results_rule import CombineResultsRule
from decoimpact.business.entities.rules.depth_average_rule import DepthAverageRule
from decoimpact.business.entities.rules.formula_rule import FormulaRule
from decoimpact.business.entities.rules.i_rule import IRule
from decoimpact.business.entities.rules.layer_filter_rule import LayerFilterRule
Expand All @@ -39,6 +40,7 @@
from decoimpact.data.api.i_classification_rule_data import IClassificationRuleData
from decoimpact.data.api.i_combine_results_rule_data import ICombineResultsRuleData
from decoimpact.data.api.i_data_access_layer import IDataAccessLayer
from decoimpact.data.api.i_depth_average_rule_data import IDepthAverageRuleData
from decoimpact.data.api.i_formula_rule_data import IFormulaRuleData
from decoimpact.data.api.i_layer_filter_rule_data import ILayerFilterRuleData
from decoimpact.data.api.i_model_data import IModelData
Expand Down Expand Up @@ -90,13 +92,22 @@ def _set_default_fields(rule_data: IRuleData, rule: RuleBase):

@staticmethod
def _create_rule(rule_data: IRuleData) -> IRule:

# from python >3.10 we can use match/case, better solution
# until then disable pylint.
# pylint: disable=too-many-branches
if isinstance(rule_data, IMultiplyRuleData):
rule = MultiplyRule(
rule_data.name,
[rule_data.input_variable],
rule_data.multipliers,
rule_data.date_range,
)
elif isinstance(rule_data, IDepthAverageRuleData):
rule = DepthAverageRule(
rule_data.name,
rule_data.input_variables,
)
elif isinstance(rule_data, ILayerFilterRuleData):
rule = LayerFilterRule(
rule_data.name,
Expand Down Expand Up @@ -162,5 +173,4 @@ def _create_rule(rule_data: IRuleData) -> IRule:

if isinstance(rule, RuleBase):
ModelBuilder._set_default_fields(rule_data, rule)

return rule
13 changes: 13 additions & 0 deletions decoimpact/crosscutting/delft3d_specific_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# This file is part of D-EcoImpact
# Copyright (C) 2022-2024 Stichting Deltares
# This program is free software distributed under the
# GNU Affero General Public License version 3.0
# A copy of the GNU Affero General Public License can be found at
# https://github.com/Deltares/D-EcoImpact/blob/main/LICENSE.md
"""
Configuration file for hardcoded delft3d variable names
"""

INTERFACES_NAME = "mesh2d_interface_z"
BED_LEVEL_NAME = "mesh2d_flowelem_bl"
WATER_LEVEL_NAME = "mesh2d_s1"
28 changes: 28 additions & 0 deletions decoimpact/data/api/i_depth_average_rule_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# This file is part of D-EcoImpact
# Copyright (C) 2022-2024 Stichting Deltares
# This program is free software distributed under the
# GNU Affero General Public License version 3.0
# A copy of the GNU Affero General Public License can be found at
# https://github.com/Deltares/D-EcoImpact/blob/main/LICENSE.md
"""
Module for IDepthAverageRuleData interface
Interfaces:
IDepthAverageRuleData
"""


from abc import ABC, abstractmethod
from typing import List

from decoimpact.data.api.i_rule_data import IRuleData


class IDepthAverageRuleData(IRuleData, ABC):
"""Data for a DepthAverageRule"""

@property
@abstractmethod
def input_variables(self) -> List[str]:
"""List with input variable name and standard depth name"""
2 changes: 1 addition & 1 deletion decoimpact/data/api/i_rolling_statistics_rule_data.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# This file is part of D-EcoImpact
# Copyright (C) 2022-2023 Stichting Deltares and D-EcoImpact contributors
# Copyright (C) 2022-2024 Stichting Deltares and D-EcoImpact contributors
# This program is free software distributed under the GNU
# Lesser General Public License version 2.1
# A copy of the GNU General Public License can be found at
Expand Down
Loading

0 comments on commit 7334cdb

Please sign in to comment.