class InputFileCollection:
    """Files directly inside one directory whose names match a regex.

    Iterating over an instance yields a ``pathlib.Path`` for each entry of
    ``path`` whose name matches ``pattern``. Matching uses ``re.match``, so
    the pattern is anchored at the start of the file name only.

    Parameters
    ----------
    path : str or path-like
        Directory to look in; stored as a ``pathlib.Path``.
    pattern : str or re.Pattern
        Regular expression applied to each entry's name; stored compiled.
    """

    def __init__(self, path, pattern):
        self.path = pathlib.Path(path)
        self.pattern = re.compile(pattern)  # Compile once, reuse per entry

    def __iter__(self):
        # Only the directory's immediate entries are inspected (no recursion).
        for file in self.path.iterdir():
            if self.pattern.match(file.name):
                yield file

    def __repr__(self):
        return (
            f"{type(self).__name__}("
            f"path={str(self.path)!r}, pattern={self.pattern.pattern!r})"
        )

    @classmethod
    def from_dict(cls, d):
        """Build a collection from a mapping with keys "path" and "pattern"."""
        return cls(d["path"], d["pattern"])

    def to_dict(self):
        """Return a plain-``str`` mapping that ``from_dict`` accepts.

        Both values are strings so the result can be dumped to YAML.
        """
        return {"path": str(self.path), "pattern": self.pattern.pattern}
def load_mfdataset(data, rule_spec):
    """
    Load a dataset from a list of files using xarray.

    Parameters
    ----------
    data : Any
        Data in the pipeline flow thus far. Not used by this step.
    rule_spec : Rule
        Rule being handled. Its file collections are read from the ``files``
        attribute when present, otherwise from ``inputs``. Each collection
        must be iterable and yield file paths.

    Returns
    -------
    The object returned by ``xr.open_mfdataset`` for all gathered files.
    """
    # NOTE(review): Rule.__init__ appears to store the collections on
    # ``inputs``, not ``files`` -- confirm which attribute is canonical.
    file_collections = getattr(rule_spec, "files", None)
    if file_collections is None:
        file_collections = rule_spec.inputs

    all_files = []
    for file_collection in file_collections:
        # extend() flattens the collection; append() of a generator would
        # leave a list of generator objects instead of file paths.
        all_files.extend(file_collection)
    all_files = _resolve_symlinks(all_files)
    mf_ds = xr.open_mfdataset(all_files, parallel=True, use_cftime=True)
    return mf_ds
# If the pattern is not marked, use the environment variable pattern = _input_pattern_from_env(config) - files = _input_files_in_path(input_pattern["path"], pattern) + files = _input_files_in_path(input_pattern, pattern) files = _resolve_symlinks(files) if year_start is not None and year_end is not None: files = _filter_by_year(files, pattern, year_start, year_end) diff --git a/src/pymorize/pipeline.py b/src/pymorize/pipeline.py index c0f4093..603038c 100644 --- a/src/pymorize/pipeline.py +++ b/src/pymorize/pipeline.py @@ -233,7 +233,7 @@ class DefaultPipeline(FrozenPipeline): """ STEPS = ( - "pymorize.generic.load_data", + "pymorize.gather_inputs.load_mfdataset", "pymorize.generic.create_cmor_directories", "pymorize.units.handle_unit_conversion", ) diff --git a/src/pymorize/rule.py b/src/pymorize/rule.py index 9b0e632..e6fcc4e 100644 --- a/src/pymorize/rule.py +++ b/src/pymorize/rule.py @@ -1,20 +1,23 @@ import copy import re import typing +import warnings from collections import OrderedDict # import questionary import yaml from . import data_request, pipeline -from .logging import logger +from .gather_inputs import InputFileCollection + +# from .logging import logger class Rule: def __init__( self, *, - input_patterns: typing.Union[str, typing.List[str]], + inputs: typing.List[dict] = [], cmor_variable: str, pipelines: typing.List[pipeline.Pipeline] = [], tables: typing.List[data_request.DataRequestTable] = [], @@ -28,8 +31,8 @@ def __init__( Parameters ---------- - input_pattern : str or list of str - A regular expression pattern or a list of patterns to match the input file path. + inputs : list of dicts for InputFileCollection + Dictionaries should contain the keys "path" and "pattern". cmor_variable : str The CMOR variable name. This is the name of the variable as it should appear in the CMIP archive. 
pipelines : list of Pipeline objects @@ -38,19 +41,8 @@ def __init__( A list of data request tables associated with this rule data_request_variables : DataRequestVariable or None : The DataRequestVariables this rule should create - - Raises - ------ - TypeError - If input_pattern is not a string or a list of strings. """ - if isinstance(input_patterns, str): - self.input_patterns = list(re.compile(input_patterns)) - elif isinstance(input_patterns, list): - self.input_patterns = [re.compile(str(p)) for p in input_patterns] - else: - raise TypeError("input_pattern must be a string or a list of strings") - + self.inputs = [InputFileCollection.from_dict(inp_dict) for inp_dict in inputs] self.cmor_variable = cmor_variable self.pipelines = pipelines or [pipeline.DefaultPipeline()] self.tables = tables @@ -67,7 +59,7 @@ def get(self, key, default=None): return getattr(self, key, default) def __repr__(self): - return f"Rule(input_patterns={self.input_patterns}, cmor_variable={self.cmor_variable}, pipelines={self.pipelines})" + return f"Rule(inputs={self.inputs}, cmor_variable={self.cmor_variable}, pipelines={self.pipelines}, tables={self.tables}, data_request_variables={self.data_request_variables})" def __str__(self): return f"Rule for {self.cmor_variable} with input patterns {self.input_patterns} and pipelines {self.pipelines}" @@ -106,7 +98,7 @@ def match_pipelines(self, pipelines, force=False): @classmethod def from_dict(cls, data): return cls( - input_patterns=data.pop("input_patterns"), + inputs=data.pop("inputs"), cmor_variable=data.pop("cmor_variable"), pipelines=data.pop("pipelines", []), **data, @@ -119,7 +111,7 @@ def from_yaml(cls, yaml_str): def to_yaml(self): return yaml.dump( { - "input_patterns": [p.pattern for p in self.input_patterns], + "inputs": [p.to_dict for p in self.input_patterns], "cmor_variable": self.cmor_variable, "pipelines": [p.to_dict() for p in self.pipelines], } @@ -136,6 +128,12 @@ def add_data_request_variable(self, drv): v for v in 
@property
def input_patterns(self):
    """Deprecated view of ``inputs`` as one compiled regex per input.

    Each regex is built from ``"<path>/<pattern>"``, where ``<pattern>`` is
    the source text of the input's pattern. Emits a ``DeprecationWarning``
    on every access.

    Returns
    -------
    list of re.Pattern
    """
    deprecated = "input_patterns is deprecated. Use inputs instead."
    # stacklevel=2 attributes the warning to the code accessing the property.
    warnings.warn(deprecated, DeprecationWarning, stacklevel=2)
    return [
        # inp.pattern may be a compiled re.Pattern; interpolating it directly
        # would embed "re.compile('...')" in the regex, so use its source text.
        re.compile(f"{inp.path}/{getattr(inp.pattern, 'pattern', inp.pattern)}")
        for inp in self.inputs
    ]
inputs=[ + { + "path": "/some/files/containing/", + "pattern": "var1.*.nc", + }, + { + "path": "/some/other/files/containing/", + "pattern": "var1_(?P\d{4}).nc", + }, ], cmor_variable="var1", pipelines=["pymorize.pipeline.TestingPipeline"], diff --git a/tests/unit/test_rule.py b/tests/unit/test_rule.py index 4eb0c84..54f78ee 100644 --- a/tests/unit/test_rule.py +++ b/tests/unit/test_rule.py @@ -6,18 +6,6 @@ from pymorize.rule import Rule -@pytest.fixture -def simple_rule(): - return Rule( - input_patterns=[ - r"/some/files/containing/var1.*.nc", - r"/some/other/files/containing/var1_(?P\d{4}).nc", - ], - cmor_variable="var1", - pipelines=["pymorize.pipeline.TestingPipeline"], - ) - - def test_direct_init(simple_rule): rule = simple_rule assert all(isinstance(ip, re.Pattern) for ip in rule.input_patterns) @@ -27,9 +15,15 @@ def test_direct_init(simple_rule): def test_from_dict(): data = { - "input_patterns": [ - r"/some/files/containing/var1.*.nc", - r"/some/other/files/containing/var1_(?P\d{4}).nc", + "inputs": [ + { + "path": "/some/files/containing/", + "pattern": "var1.*.nc", + }, + { + "path": "/some/other/files/containing/", + "pattern": r"var1_(?P\d{4}).nc", + }, ], "cmor_variable": "var1", "pipelines": ["pymorize.pipeline.TestingPipeline"], @@ -42,11 +36,13 @@ def test_from_dict(): def test_from_yaml(): yaml_str = """ - input_patterns: - - /some/files/containing/var1.*.nc - - /some/other/files/containing/var1_(?P\d{4}).nc + inputs: + - path: /some/files/containing/ + pattern: var1.*.nc + - path: /some/other/files/containing/ + pattern: var1_(?P\d{4}).nc cmor_variable: var1 - pipelines: + pipelines: - pymorize.pipeline.TestingPipeline """ rule = Rule.from_yaml(yaml_str)