Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Gather Inputs pipeline step #33

Merged
5 commits merged into from
Sep 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def read(filename):
"chemicals",
"click-loguru",
"dask",
"deprecation",
"distributed",
"dill",
"dpath",
Expand Down
47 changes: 43 additions & 4 deletions src/pymorize/gather_inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,37 @@
import re
from typing import List

import deprecation
import dpath
import xarray as xr

_PATTERN_ENV_VAR_NAME_ADDR = "/pymorize/pattern_env_var_name"
"""str: The address in the YAML file which stores the environment variable to be used for the pattern"""
_PATTERN_ENV_VAR_NAME_DEFAULT = "PYMORIZE_INPUT_PATTERN"
"""str: The default value for the environment variable to be used for the pattern"""
_PATTERN_ENV_VAR_VALUE_ADDR = "/pymorize/pattern_env_var_value"
"""str: The address in the YAML file which stores the environment variable's value to be used if the variable is not set"""
"""str: The address in the YAML file which stores the environment variable's value"""
_PATTERN_ENV_VAR_VALUE_DEFAULT = ".*" # Default: match anything
"""str: The default value for the environment variable's value to be used if the variable is not set"""


class InputFileCollection:
def __init__(self, path, pattern):
self.path = pathlib.Path(path)
self.pattern = re.compile(pattern) # Compile the regex pattern

def __iter__(self):
for file in self.path.iterdir():
if self.pattern.match(
file.name
): # Check if the filename matches the pattern
yield file

@classmethod
def from_dict(cls, d):
return cls(d["path"], d["pattern"])


def _input_pattern_from_env(config: dict) -> re.Pattern:
"""
Get the input pattern from the environment variable.
Expand Down Expand Up @@ -213,6 +232,26 @@ def _validate_rule_has_marked_regex(
return all(re.search(rf"\(\?P<{mark}>", pattern) for mark in required_marks)


def load_mfdataset(data, rule_spec):
"""
Load a dataset from a list of files using xarray.

Parameters
----------
data : Any
Data in the pipeline flow thus far.
rule_spec : Rule
Rule being handled
"""
all_files = []
for file_collection in rule_spec.files:
all_files.append(f for f in file_collection)
all_files = _resolve_symlinks(all_files)
mf_ds = xr.open_mfdataset(all_files, parallel=True, use_cftime=True)
return mf_ds


@deprecation.deprecated(details="Use load_mfdataset in your pipeline instead!")
def gather_inputs(config: dict) -> dict:
"""
Gather possible inputs from a user directory.
Expand Down Expand Up @@ -300,13 +339,13 @@ def gather_inputs(config: dict) -> dict:
if year_end is not None:
year_end = int(year_end)
for input_pattern in input_patterns:
if _validate_rule_has_marked_regex(input_pattern):
pattern = re.compile(input_pattern["pattern"])
if _validate_rule_has_marked_regex(rule):
pattern = re.compile(rule["pattern"])
else:
# FIXME(PG): This needs to be thought through...
# If the pattern is not marked, use the environment variable
pattern = _input_pattern_from_env(config)
files = _input_files_in_path(input_pattern["path"], pattern)
files = _input_files_in_path(input_pattern, pattern)
files = _resolve_symlinks(files)
if year_start is not None and year_end is not None:
files = _filter_by_year(files, pattern, year_start, year_end)
Expand Down
2 changes: 1 addition & 1 deletion src/pymorize/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ class DefaultPipeline(FrozenPipeline):
"""

STEPS = (
"pymorize.generic.load_data",
"pymorize.gather_inputs.load_mfdataset",
"pymorize.generic.create_cmor_directories",
"pymorize.units.handle_unit_conversion",
)
Expand Down
36 changes: 17 additions & 19 deletions src/pymorize/rule.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,23 @@
import copy
import re
import typing
import warnings
from collections import OrderedDict

# import questionary
import yaml

from . import data_request, pipeline
from .logging import logger
from .gather_inputs import InputFileCollection

# from .logging import logger


class Rule:
def __init__(
self,
*,
input_patterns: typing.Union[str, typing.List[str]],
inputs: typing.List[dict] = [],
cmor_variable: str,
pipelines: typing.List[pipeline.Pipeline] = [],
tables: typing.List[data_request.DataRequestTable] = [],
Expand All @@ -28,8 +31,8 @@ def __init__(

Parameters
----------
input_pattern : str or list of str
A regular expression pattern or a list of patterns to match the input file path.
inputs : list of dicts for InputFileCollection
Dictionaries should contain the keys "path" and "pattern".
cmor_variable : str
The CMOR variable name. This is the name of the variable as it should appear in the CMIP archive.
pipelines : list of Pipeline objects
Expand All @@ -38,19 +41,8 @@ def __init__(
A list of data request tables associated with this rule
data_request_variables : DataRequestVariable or None :
The DataRequestVariables this rule should create

Raises
------
TypeError
If input_pattern is not a string or a list of strings.
"""
if isinstance(input_patterns, str):
self.input_patterns = list(re.compile(input_patterns))
elif isinstance(input_patterns, list):
self.input_patterns = [re.compile(str(p)) for p in input_patterns]
else:
raise TypeError("input_pattern must be a string or a list of strings")

self.inputs = [InputFileCollection.from_dict(inp_dict) for inp_dict in inputs]
self.cmor_variable = cmor_variable
self.pipelines = pipelines or [pipeline.DefaultPipeline()]
self.tables = tables
Expand All @@ -67,7 +59,7 @@ def get(self, key, default=None):
return getattr(self, key, default)

def __repr__(self):
return f"Rule(input_patterns={self.input_patterns}, cmor_variable={self.cmor_variable}, pipelines={self.pipelines})"
return f"Rule(inputs={self.inputs}, cmor_variable={self.cmor_variable}, pipelines={self.pipelines}, tables={self.tables}, data_request_variables={self.data_request_variables})"

def __str__(self):
return f"Rule for {self.cmor_variable} with input patterns {self.input_patterns} and pipelines {self.pipelines}"
Expand Down Expand Up @@ -106,7 +98,7 @@ def match_pipelines(self, pipelines, force=False):
@classmethod
def from_dict(cls, data):
return cls(
input_patterns=data.pop("input_patterns"),
inputs=data.pop("inputs"),
cmor_variable=data.pop("cmor_variable"),
pipelines=data.pop("pipelines", []),
**data,
Expand All @@ -119,7 +111,7 @@ def from_yaml(cls, yaml_str):
def to_yaml(self):
return yaml.dump(
{
"input_patterns": [p.pattern for p in self.input_patterns],
"inputs": [p.to_dict for p in self.input_patterns],
"cmor_variable": self.cmor_variable,
"pipelines": [p.to_dict() for p in self.pipelines],
}
Expand All @@ -136,6 +128,12 @@ def add_data_request_variable(self, drv):
v for v in self.data_request_variable if v is not None
]

@property
def input_patterns(self):
deprecated = "input_patterns is deprecated. Use inputs instead."
warnings.warn(deprecated, DeprecationWarning)
return [re.compile(f"{inp.path}/{inp.pattern}") for inp in self.inputs]

def clone(self):
return copy.deepcopy(self)

Expand Down
7 changes: 5 additions & 2 deletions tests/configs/test_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -29,5 +29,8 @@ rules:
cmor_variable: "tas"
input_type: "xr.DataArray"
input_source: "xr_tutorial"
input_patterns:
- "test_input"
inputs:
- path: "./"
pattern: "test_input"
- path: "./some/other/path"
pattern: "test_input2"
36 changes: 27 additions & 9 deletions tests/fixtures/sample_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,15 @@
@pytest.fixture
def simple_rule():
return Rule(
input_patterns=[
r"/some/files/containing/var1.*.nc",
r"/some/other/files/containing/var1_(?P<year>\d{4}).nc",
inputs=[
{
"path": "/some/files/containing/",
"pattern": "var1.*.nc",
},
{
"path": "/some/other/files/containing/",
"pattern": "var1_(?P<year>\d{4}).nc",
},
],
cmor_variable="var1",
pipelines=["pymorize.pipeline.TestingPipeline"],
Expand All @@ -19,9 +25,15 @@ def simple_rule():
@pytest.fixture
def rule_with_mass_units():
r = Rule(
input_patterns=[
r"/some/files/containing/var1.*.nc",
r"/some/other/files/containing/var1_(?P<year>\d{4}).nc",
inputs=[
{
"path": "/some/files/containing/",
"pattern": "var1.*.nc",
},
{
"path": "/some/other/files/containing/",
"pattern": "var1_(?P<year>\d{4}).nc",
},
],
cmor_variable="var1",
pipelines=["pymorize.pipeline.TestingPipeline"],
Expand All @@ -47,9 +59,15 @@ def rule_with_mass_units():
@pytest.fixture
def rule_with_units():
r = Rule(
input_patterns=[
r"/some/files/containing/var1.*.nc",
r"/some/other/files/containing/var1_(?P<year>\d{4}).nc",
inputs=[
{
"path": "/some/files/containing/",
"pattern": "var1.*.nc",
},
{
"path": "/some/other/files/containing/",
"pattern": "var1_(?P<year>\d{4}).nc",
},
],
cmor_variable="var1",
pipelines=["pymorize.pipeline.TestingPipeline"],
Expand Down
34 changes: 15 additions & 19 deletions tests/unit/test_rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,6 @@
from pymorize.rule import Rule


@pytest.fixture
def simple_rule():
return Rule(
input_patterns=[
r"/some/files/containing/var1.*.nc",
r"/some/other/files/containing/var1_(?P<year>\d{4}).nc",
],
cmor_variable="var1",
pipelines=["pymorize.pipeline.TestingPipeline"],
)


def test_direct_init(simple_rule):
rule = simple_rule
assert all(isinstance(ip, re.Pattern) for ip in rule.input_patterns)
Expand All @@ -27,9 +15,15 @@ def test_direct_init(simple_rule):

def test_from_dict():
data = {
"input_patterns": [
r"/some/files/containing/var1.*.nc",
r"/some/other/files/containing/var1_(?P<year>\d{4}).nc",
"inputs": [
{
"path": "/some/files/containing/",
"pattern": "var1.*.nc",
},
{
"path": "/some/other/files/containing/",
"pattern": r"var1_(?P<year>\d{4}).nc",
},
],
"cmor_variable": "var1",
"pipelines": ["pymorize.pipeline.TestingPipeline"],
Expand All @@ -42,11 +36,13 @@ def test_from_dict():

def test_from_yaml():
yaml_str = """
input_patterns:
- /some/files/containing/var1.*.nc
- /some/other/files/containing/var1_(?P<year>\d{4}).nc
inputs:
- path: /some/files/containing/
pattern: var1.*.nc
- path: /some/other/files/containing/
pattern: var1_(?P<year>\d{4}).nc
cmor_variable: var1
pipelines:
pipelines:
- pymorize.pipeline.TestingPipeline
"""
rule = Rule.from_yaml(yaml_str)
Expand Down
Loading