From e05c201c59cbae44f84e9bb66e9984b38756f6f9 Mon Sep 17 00:00:00 2001 From: Paul Gierz Date: Tue, 10 Sep 2024 11:56:36 +0200 Subject: [PATCH 1/5] fix: small mistake in gather inputs --- src/pymorize/gather_inputs.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/pymorize/gather_inputs.py b/src/pymorize/gather_inputs.py index 3ebcff0..5534b79 100644 --- a/src/pymorize/gather_inputs.py +++ b/src/pymorize/gather_inputs.py @@ -300,13 +300,13 @@ def gather_inputs(config: dict) -> dict: if year_end is not None: year_end = int(year_end) for input_pattern in input_patterns: - if _validate_rule_has_marked_regex(input_pattern): - pattern = re.compile(input_pattern["pattern"]) + if _validate_rule_has_marked_regex(rule): + pattern = re.compile(rule["pattern"]) else: # FIXME(PG): This needs to be thought through... # If the pattern is not marked, use the environment variable pattern = _input_pattern_from_env(config) - files = _input_files_in_path(input_pattern["path"], pattern) + files = _input_files_in_path(input_pattern, pattern) files = _resolve_symlinks(files) if year_start is not None and year_end is not None: files = _filter_by_year(files, pattern, year_start, year_end) From 7647db848f8ca7efb80bacb99d77a4b5b4b9c7a6 Mon Sep 17 00:00:00 2001 From: Paul Gierz Date: Wed, 11 Sep 2024 12:29:12 +0200 Subject: [PATCH 2/5] fix(gather_inputs.py): deprecates the gather_inputs function This will deprecate the gather_inputs function in favour of a pipeline step. At the moment, you just get a deprecation warning if you try to use the function. --- setup.py | 1 + src/pymorize/gather_inputs.py | 2 ++ 2 files changed, 3 insertions(+) diff --git a/setup.py b/setup.py index 20cae21..64ff0b2 100644 --- a/setup.py +++ b/setup.py @@ -35,6 +35,7 @@ def read(filename): "chemicals", "click-loguru", "dask", + "deprecation", "distributed", "dill", "dpath", diff --git a/src/pymorize/gather_inputs.py b/src/pymorize/gather_inputs.py index 5534b79..986e838 100644 --- a/src/pymorize/gather_inputs.py +++ b/src/pymorize/gather_inputs.py @@ -7,6 +7,7 @@ import re from typing import List +import deprecation import dpath _PATTERN_ENV_VAR_NAME_ADDR = "/pymorize/pattern_env_var_name" @@ -213,6 +214,7 @@ def _validate_rule_has_marked_regex( return all(re.search(rf"\(\?P<{mark}>", pattern) for mark in required_marks) +@deprecation.deprecated(details="Use load_mfdataset in your pipeline instead!") def gather_inputs(config: dict) -> dict: """ Gather possible inputs from a user directory. From 8b6824142dc88623b3a7d4c6bd8d09a1acb47f99 Mon Sep 17 00:00:00 2001 From: Paul Gierz Date: Wed, 11 Sep 2024 12:33:31 +0200 Subject: [PATCH 3/5] feat: new API for inputs in Rule Rule has a new API. Instead of input patterns as a list of strings, you now provide inputs, as a list of dictionaries. These should be key/values with path and pattern as keys, later used to get the appropriate file(s) from each path. Tests are also updated to reflect the new way of loading files into the Rules. BREAKING CHANGE: Any signature constructing a Rule object will need to be changed! --- src/pymorize/gather_inputs.py | 37 ++++++++++++++++++++++++++++++++++ src/pymorize/rule.py | 36 ++++++++++++++++----------------- tests/configs/test_config.yaml | 7 +++++-- tests/fixtures/sample_rules.py | 36 ++++++++++++++++++++++++--------- tests/unit/test_rule.py | 34 ++++++++++++++----------------- 5 files changed, 101 insertions(+), 49 deletions(-) diff --git a/src/pymorize/gather_inputs.py b/src/pymorize/gather_inputs.py index 986e838..4d14f67 100644 --- a/src/pymorize/gather_inputs.py +++ b/src/pymorize/gather_inputs.py @@ -9,6 +9,7 @@ import deprecation import dpath +import xarray as xr _PATTERN_ENV_VAR_NAME_ADDR = "/pymorize/pattern_env_var_name" """str: The address in the YAML file which stores the environment variable to be used for the pattern""" @@ -20,6 +21,23 @@ """str: The default value for the environment variable's value to be used if the variable is not set""" +class InputFileCollection: + def __init__(self, path, pattern): + self.path = pathlib.Path(path) + self.pattern = re.compile(pattern) # Compile the regex pattern + + def __iter__(self): + for file in self.path.iterdir(): + if self.pattern.match( + file.name + ): # Check if the filename matches the pattern + yield file + + @classmethod + def from_dict(cls, d): + return cls(d["path"], d["pattern"]) + + def _input_pattern_from_env(config: dict) -> re.Pattern: """ Get the input pattern from the environment variable. @@ -214,6 +232,25 @@ def _validate_rule_has_marked_regex( return all(re.search(rf"\(\?P<{mark}>", pattern) for mark in required_marks) +def load_mfdataset(data, rule_spec): + """ + Load a dataset from a list of files using xarray. + + Parameters + ---------- + data : Any + Data in the pipeline flow thus far. + rule_spec : Rule + Rule being handled + """ + all_files = [] + for file_collection in rule_spec.files: + all_files.append(f for f in file_collection) + all_files = _resolve_symlinks(all_files) + mf_ds = xr.open_mfdataset(all_files, parallel=True, use_cftime=True) + return mf_ds + + @deprecation.deprecated(details="Use load_mfdataset in your pipeline instead!") def gather_inputs(config: dict) -> dict: """ diff --git a/src/pymorize/rule.py b/src/pymorize/rule.py index 9b0e632..e6fcc4e 100644 --- a/src/pymorize/rule.py +++ b/src/pymorize/rule.py @@ -1,20 +1,23 @@ import copy import re import typing +import warnings from collections import OrderedDict # import questionary import yaml from . import data_request, pipeline -from .logging import logger +from .gather_inputs import InputFileCollection + +# from .logging import logger class Rule: def __init__( self, *, - input_patterns: typing.Union[str, typing.List[str]], + inputs: typing.List[dict] = [], cmor_variable: str, pipelines: typing.List[pipeline.Pipeline] = [], tables: typing.List[data_request.DataRequestTable] = [], @@ -28,8 +31,8 @@ def __init__( Parameters ---------- - input_pattern : str or list of str - A regular expression pattern or a list of patterns to match the input file path. + inputs : list of dicts for InputFileCollection + Dictionaries should contain the keys "path" and "pattern". cmor_variable : str The CMOR variable name. This is the name of the variable as it should appear in the CMIP archive. pipelines : list of Pipeline objects @@ -38,19 +41,8 @@ def __init__( A list of data request tables associated with this rule data_request_variables : DataRequestVariable or None : The DataRequestVariables this rule should create - - Raises - ------ - TypeError - If input_pattern is not a string or a list of strings. """ - if isinstance(input_patterns, str): - self.input_patterns = list(re.compile(input_patterns)) - elif isinstance(input_patterns, list): - self.input_patterns = [re.compile(str(p)) for p in input_patterns] - else: - raise TypeError("input_pattern must be a string or a list of strings") - + self.inputs = [InputFileCollection.from_dict(inp_dict) for inp_dict in inputs] self.cmor_variable = cmor_variable self.pipelines = pipelines or [pipeline.DefaultPipeline()] self.tables = tables @@ -67,7 +59,7 @@ def get(self, key, default=None): return getattr(self, key, default) def __repr__(self): - return f"Rule(input_patterns={self.input_patterns}, cmor_variable={self.cmor_variable}, pipelines={self.pipelines})" + return f"Rule(inputs={self.inputs}, cmor_variable={self.cmor_variable}, pipelines={self.pipelines}, tables={self.tables}, data_request_variables={self.data_request_variables})" def __str__(self): return f"Rule for {self.cmor_variable} with input patterns {self.input_patterns} and pipelines {self.pipelines}" @@ -106,7 +98,7 @@ def match_pipelines(self, pipelines, force=False): @classmethod def from_dict(cls, data): return cls( - input_patterns=data.pop("input_patterns"), + inputs=data.pop("inputs"), cmor_variable=data.pop("cmor_variable"), pipelines=data.pop("pipelines", []), **data, @@ -119,7 +111,7 @@ def from_yaml(cls, yaml_str): def to_yaml(self): return yaml.dump( { - "input_patterns": [p.pattern for p in self.input_patterns], + "inputs": [p.to_dict for p in self.input_patterns], "cmor_variable": self.cmor_variable, "pipelines": [p.to_dict() for p in self.pipelines], } @@ -136,6 +128,12 @@ def add_data_request_variable(self, drv): v for v in self.data_request_variable if v is not None ] + @property + def input_patterns(self): + deprecated = "input_patterns is deprecated. Use inputs instead." + warnings.warn(deprecated, DeprecationWarning) + return [re.compile(f"{inp.path}/{inp.pattern}") for inp in self.inputs] + def clone(self): return copy.deepcopy(self) diff --git a/tests/configs/test_config.yaml b/tests/configs/test_config.yaml index 7b1c9e4..81ec601 100644 --- a/tests/configs/test_config.yaml +++ b/tests/configs/test_config.yaml @@ -29,5 +29,8 @@ rules: cmor_variable: "tas" input_type: "xr.DataArray" input_source: "xr_tutorial" - input_patterns: - - "test_input" + inputs: + - path: "./" + pattern: "test_input" + - path: "./some/other/path" + pattern: "test_input2" diff --git a/tests/fixtures/sample_rules.py b/tests/fixtures/sample_rules.py index a19994e..0a356d1 100644 --- a/tests/fixtures/sample_rules.py +++ b/tests/fixtures/sample_rules.py @@ -7,9 +7,15 @@ @pytest.fixture def simple_rule(): return Rule( - input_patterns=[ - r"/some/files/containing/var1.*.nc", - r"/some/other/files/containing/var1_(?P\d{4}).nc", + inputs=[ + { + "path": "/some/files/containing/", + "pattern": "var1.*.nc", + }, + { + "path": "/some/other/files/containing/", + "pattern": "var1_(?P\d{4}).nc", + }, ], cmor_variable="var1", pipelines=["pymorize.pipeline.TestingPipeline"], @@ -19,9 +25,15 @@ def simple_rule(): @pytest.fixture def rule_with_mass_units(): r = Rule( - input_patterns=[ - r"/some/files/containing/var1.*.nc", - r"/some/other/files/containing/var1_(?P\d{4}).nc", + inputs=[ + { + "path": "/some/files/containing/", + "pattern": "var1.*.nc", + }, + { + "path": "/some/other/files/containing/", + "pattern": "var1_(?P\d{4}).nc", + }, ], cmor_variable="var1", pipelines=["pymorize.pipeline.TestingPipeline"], @@ -47,9 +59,15 @@ def rule_with_mass_units(): @pytest.fixture def rule_with_units(): r = Rule( - input_patterns=[ - r"/some/files/containing/var1.*.nc", - r"/some/other/files/containing/var1_(?P\d{4}).nc", + inputs=[ + { + "path": "/some/files/containing/", + "pattern": "var1.*.nc", + }, + { + "path": "/some/other/files/containing/", + "pattern": "var1_(?P\d{4}).nc", + }, ], cmor_variable="var1", pipelines=["pymorize.pipeline.TestingPipeline"], diff --git a/tests/unit/test_rule.py b/tests/unit/test_rule.py index 4eb0c84..54f78ee 100644 --- a/tests/unit/test_rule.py +++ b/tests/unit/test_rule.py @@ -6,18 +6,6 @@ from pymorize.rule import Rule -@pytest.fixture -def simple_rule(): - return Rule( - input_patterns=[ - r"/some/files/containing/var1.*.nc", - r"/some/other/files/containing/var1_(?P\d{4}).nc", - ], - cmor_variable="var1", - pipelines=["pymorize.pipeline.TestingPipeline"], - ) - - def test_direct_init(simple_rule): rule = simple_rule assert all(isinstance(ip, re.Pattern) for ip in rule.input_patterns) @@ -27,9 +15,15 @@ def test_direct_init(simple_rule): def test_from_dict(): data = { - "input_patterns": [ - r"/some/files/containing/var1.*.nc", - r"/some/other/files/containing/var1_(?P\d{4}).nc", + "inputs": [ + { + "path": "/some/files/containing/", + "pattern": "var1.*.nc", + }, + { + "path": "/some/other/files/containing/", + "pattern": r"var1_(?P\d{4}).nc", + }, ], "cmor_variable": "var1", "pipelines": ["pymorize.pipeline.TestingPipeline"], @@ -42,11 +36,13 @@ def test_from_dict(): def test_from_yaml(): yaml_str = """ - input_patterns: - - /some/files/containing/var1.*.nc - - /some/other/files/containing/var1_(?P\d{4}).nc + inputs: + - path: /some/files/containing/ + pattern: var1.*.nc + - path: /some/other/files/containing/ + pattern: var1_(?P\d{4}).nc cmor_variable: var1 - pipelines: + pipelines: - pymorize.pipeline.TestingPipeline """ rule = Rule.from_yaml(yaml_str) From 0e1ce572b665621f46f8d33d9d79e5ba512e371d Mon Sep 17 00:00:00 2001 From: Paul Gierz Date: Wed, 11 Sep 2024 12:35:46 +0200 Subject: [PATCH 4/5] doc: fix too long docstring --- src/pymorize/gather_inputs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pymorize/gather_inputs.py b/src/pymorize/gather_inputs.py index 4d14f67..6a36af0 100644 --- a/src/pymorize/gather_inputs.py +++ b/src/pymorize/gather_inputs.py @@ -16,7 +16,7 @@ _PATTERN_ENV_VAR_NAME_DEFAULT = "PYMORIZE_INPUT_PATTERN" """str: The default value for the environment variable to be used for the pattern""" _PATTERN_ENV_VAR_VALUE_ADDR = "/pymorize/pattern_env_var_value" -"""str: The address in the YAML file which stores the environment variable's value to be used if the variable is not set""" +"""str: The address in the YAML file which stores the environment variable's value""" _PATTERN_ENV_VAR_VALUE_DEFAULT = ".*" # Default: match anything """str: The default value for the environment variable's value to be used if the variable is not set""" From 6b614f4a4fce66ca40834e090437d9e6621ee756 Mon Sep 17 00:00:00 2001 From: Paul Gierz Date: Wed, 11 Sep 2024 12:41:11 +0200 Subject: [PATCH 5/5] feat: forgot to add step in Pipeline --- src/pymorize/pipeline.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pymorize/pipeline.py b/src/pymorize/pipeline.py index c0f4093..603038c 100644 --- a/src/pymorize/pipeline.py +++ b/src/pymorize/pipeline.py @@ -233,7 +233,7 @@ class DefaultPipeline(FrozenPipeline): """ STEPS = ( - "pymorize.generic.load_data", + "pymorize.gather_inputs.load_mfdataset", "pymorize.generic.create_cmor_directories", "pymorize.units.handle_unit_conversion", )