Skip to content

Commit

Permalink
🩹 Fix wrong file loading due to partial filename matching in Project (#…
Browse files Browse the repository at this point in the history
…1212)

This PR fixes an issue when using a project: because file names were matched partially, the wrong file was loaded if its name started with the same string as the intended file but was higher in the string sort order (i.e. it sorted before the intended file).
E.g. if you had 2 files `model.yml` and `model-target.yml` in your project's `models` folder and called `project.load_model("model")`, it would load `model-target.yml`.
In addition, some code duplication was removed and the error messages for missing files were improved.

### Change summary

- [🧪 Changed project tests to fail due to partial filename matching](ba956d6)
  (Note: The "bad" files were added to the markdown repr string since it should show all files)
- [🩹 Fix file loading](29d3dc5)
- [🧪🩹 Fixed tests that only passed due to wrong file matching](171d897)
- [🧪🧹 Reduced duplication in test](d1becfb)
- [👌 Improved error messages on file not found error](8ab8ba3)
  • Loading branch information
s-weigand authored Jan 7, 2023
1 parent 7dafb8f commit ec4a90c
Show file tree
Hide file tree
Showing 5 changed files with 101 additions and 53 deletions.
1 change: 1 addition & 0 deletions changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
- 🩹 Fix pretty_format_numerical for negative values (#1192)
- 🩹 Fix yaml result saving with relative paths (#1199)
- 🩹 Fix model markdown render for items without label (#1213)
- 🩹 Fix wrong file loading due to partial filename matching in Project (#1212)
<!-- Fix within the 0.7.0 release cycle, therefore hidden:
- 🩹 Fix the matrix provider alignment/reduction ('grouping') issues introduced in #1175 (#1190)
-->
Expand Down
43 changes: 23 additions & 20 deletions glotaran/project/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,8 +169,11 @@ def load_data(self, dataset_name: str) -> xr.Dataset | xr.DataArray:
"""
try:
return self._data_registry.load_item(dataset_name)
except ValueError as e:
raise ValueError(f"Dataset {dataset_name!r} does not exist.") from e
except ValueError as err:
raise ValueError(
f"Dataset {dataset_name!r} does not exist. "
f"Known Datasets are: {list(self._data_registry.items.keys())}"
) from err

def import_data(
self,
Expand Down Expand Up @@ -238,8 +241,11 @@ def load_model(self, name: str) -> Model:
"""
try:
return self._model_registry.load_item(name)
except ValueError as e:
raise ValueError(f"Model {name!r} does not exist.") from e
except ValueError as err:
raise ValueError(
f"Model {name!r} does not exist. "
f"Known Models are: {list(self._model_registry.items.keys())}"
) from err

def generate_model(
self,
Expand Down Expand Up @@ -327,8 +333,11 @@ def load_parameters(self, parameters_name: str) -> Parameters:
"""
try:
return self._parameter_registry.load_item(parameters_name)
except ValueError as e:
raise ValueError(f"Parameters '{parameters_name}' does not exist.") from e
except ValueError as err:
raise ValueError(
f"Parameters '{parameters_name}' does not exist. "
f"Known Parameters are: {list(self._parameter_registry.items.keys())}"
) from err

def generate_parameters(
self,
Expand Down Expand Up @@ -417,16 +426,11 @@ def get_result_path(self, result_name: str, *, latest: bool = False) -> Path:
------
ValueError
Raised if result does not exist.
"""
result_name = self._result_registry._latest_result_name_fallback(
result_name, latest=latest
)
path = self._result_registry.directory / result_name
if self._result_registry.is_item(path):
return path
raise ValueError(f"Result {result_name!r} does not exist.")
.. # noqa: DAR402
"""
return self._result_registry._latest_result_path_fallback(result_name, latest=latest)

def get_latest_result_path(self, result_name: str) -> Path:
"""Get the path to a result with name ``name``.
Expand Down Expand Up @@ -471,14 +475,13 @@ def load_result(self, result_name: str, *, latest: bool = False) -> Result:
------
ValueError
Raised if result does not exist.
.. # noqa: DAR402
"""
result_name = self._result_registry._latest_result_name_fallback(
result_name, latest=latest
return self._result_registry._loader(
self._result_registry._latest_result_path_fallback(result_name, latest=latest)
)
try:
return self._result_registry.load_item(result_name)
except ValueError as e:
raise ValueError(f"Result {result_name!r} does not exist.") from e

def load_latest_result(self, result_name: str) -> Result:
"""Load a result.
Expand Down
10 changes: 5 additions & 5 deletions glotaran/project/project_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,11 +98,11 @@ def load_item(self, name: str) -> Any:
ValueError
Raise if the item does not exist.
"""
try:
path = next(p for p in self._directory.iterdir() if name in p.name)
except StopIteration as e:
raise ValueError(f"No Item with name '{name}' exists.") from e
return self._loader(path)
if name in self.items:
return self._loader(self.items[name])
raise ValueError(
f"No Item with name '{name}' exists. Known items are: {list(self.items.keys())}"
)

def markdown(self, join_indentation: int = 0) -> MarkdownStr:
"""Format the registry items as a markdown text.
Expand Down
26 changes: 19 additions & 7 deletions glotaran/project/project_result_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def previous_result_paths(self, base_name: str) -> list[Path]:
"""
return sorted(self.directory.glob(f"{base_name}_run_*"))

def _latest_result_name_fallback(self, name: str, *, latest: bool = False) -> str:
def _latest_result_path_fallback(self, name: str, *, latest: bool = False) -> Path:
"""Fallback when a user forgets to specify the run to get a result.
If ``name`` contains the run number this will just return ``name``,
Expand All @@ -69,14 +69,20 @@ def _latest_result_name_fallback(self, name: str, *, latest: bool = False) -> st
Parameters
----------
name: str
Name of the result, which should contain the run specifyer.
Name of the result, which should contain the run specifier.
latest: bool
Flag to deactivate warning about using latest result. Defaults to False
Flag to deactivate warning about using latest result. Defaults to False.
Returns
-------
str
Name used to retrieve a result.
Path
Path to the result (latest result if ``name`` does not match the result pattern).
Raises
------
ValueError
Raised if result does not exist.
"""
if re.match(self.result_pattern, name) is None:
if latest is False:
Expand All @@ -89,8 +95,14 @@ def _latest_result_name_fallback(self, name: str, *, latest: bool = False) -> st
stacklevel=3,
)
previous_result_paths = self.previous_result_paths(name) or [Path(name)]
return previous_result_paths[-1].stem
return name
name = previous_result_paths[-1].stem
path = self._directory / name
if self.is_item(path):
return path

raise ValueError(
f"Result {name!r} does not exist. Known Results are: {list(self.items.keys())}"
)

def create_result_run_name(self, base_name: str) -> str:
"""Create a result name for a model.
Expand Down
74 changes: 53 additions & 21 deletions glotaran/project/test/test_project.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from importlib.metadata import distribution
from pathlib import Path
from shutil import rmtree
from textwrap import dedent
from typing import Literal

Expand Down Expand Up @@ -51,12 +52,14 @@ def test_init(project_folder: Path, project_file: Path):


def test_create(project_folder: Path):
rmtree(project_folder)
Project.create(project_folder)
with pytest.raises(FileExistsError):
assert Project.create(project_folder)


def test_open(project_folder: Path, project_file: Path):
rmtree(project_folder)
project_from_folder = Project.open(project_folder)

project_from_file = Project.open(project_file)
Expand All @@ -71,6 +74,11 @@ def test_open(project_folder: Path, project_file: Path):
assert not project.has_parameters
assert not project.has_results

# Will cause the following tests to fail on bad fuzzy matching due to higher string sort order
(project_folder / "data/dataset_1-bad.nc").touch()
(project_folder / "models/test_model-bad.yml").touch()
(project_folder / "parameters/test_parameters-bad.yml").touch()


def test_open_diff_version(tmp_path: Path):
"""Loading from file overwrites current version."""
Expand Down Expand Up @@ -262,26 +270,25 @@ def test_load_result_warnings(project_folder: Path, project_file: Path):
"""Warn when using fallback to latest result."""
project = Project.open(project_file)

expected_warning_text = (
"Result name 'test' is missing the run specifier, "
"falling back to try getting latest result. "
"Use latest=True to mute this warning."
)

with pytest.warns(UserWarning) as recwarn:
assert project_folder / "results" / "test_run_0001" == project.get_result_path("test")

assert len(recwarn) == 1
assert Path(recwarn[0].filename).samefile(__file__)
assert recwarn[0].message.args[0] == (
"Result name 'test' is missing the run specifier, "
"falling back to try getting latest result. "
"Use latest=True to mute this warning."
)
assert recwarn[0].message.args[0] == expected_warning_text

with pytest.warns(UserWarning) as recwarn:
assert isinstance(project.load_result("test"), Result)

assert len(recwarn) == 1
assert Path(recwarn[0].filename).samefile(__file__)
assert recwarn[0].message.args[0] == (
"Result name 'test' is missing the run specifier, "
"falling back to try getting latest result. "
"Use latest=True to mute this warning."
)
assert recwarn[0].message.args[0] == expected_warning_text


def test_getting_items(project_file: Path):
Expand Down Expand Up @@ -323,7 +330,7 @@ def test_generators_allow_overwrite(project_folder: Path, project_file: Path):
project.generate_model(
"test_model", "decay_parallel", {"nr_compartments": 3}, allow_overwrite=True
)
new_model = project.load_model("test")
new_model = project.load_model("test_model")
assert "megacomplex_parallel_decay" in new_model.megacomplex

comapartments = load_dict(model_file, is_file=True)["megacomplex"][
Expand All @@ -332,13 +339,13 @@ def test_generators_allow_overwrite(project_folder: Path, project_file: Path):

assert len(comapartments) == 3

project.generate_parameters("test", allow_overwrite=True)
project.generate_parameters("test_model", "test_parameters", allow_overwrite=True)
parameters = load_parameters(parameter_file)

assert len(list(filter(lambda p: p.label.startswith("rates"), parameters.all()))) == 3


def test_missing_file_errors(tmp_path: Path):
def test_missing_file_errors(tmp_path: Path, project_folder: Path):
"""Error when accessing non existing files."""
with pytest.raises(FileNotFoundError) as exc_info:
Project.open(tmp_path, create_if_not_exist=False)
Expand All @@ -348,42 +355,64 @@ def test_missing_file_errors(tmp_path: Path):
== f"Project file {(tmp_path/'project.gta').as_posix()} does not exists."
)

project = Project.open(tmp_path)
project = Project.open(project_folder)

with pytest.raises(ValueError) as exc_info:
project.load_data("not-existing")

assert str(exc_info.value) == "Dataset 'not-existing' does not exist."
assert str(exc_info.value) == (
"Dataset 'not-existing' does not exist. "
"Known Datasets are: ['dataset_1', 'dataset_1-bad', 'test_data']"
)

with pytest.raises(ValueError) as exc_info:
project.load_model("not-existing")

assert str(exc_info.value) == "Model 'not-existing' does not exist."
assert str(exc_info.value) == (
"Model 'not-existing' does not exist. "
"Known Models are: ['test_model', 'test_model-bad']"
)

with pytest.raises(ValueError) as exc_info:
project.load_parameters("not-existing")

assert str(exc_info.value) == "Parameters 'not-existing' does not exist."
assert str(exc_info.value) == (
"Parameters 'not-existing' does not exist. "
"Known Parameters are: ['test_parameters', 'test_parameters-bad']"
)

with pytest.raises(ValueError) as exc_info:
project.load_result("not-existing_run_0000")

assert str(exc_info.value) == "Result 'not-existing_run_0000' does not exist."
expected_known_results = (
"Known Results are: "
"['sequential_run_0000', 'sequential_run_0001', 'test_run_0000', 'test_run_0001']"
)

assert str(exc_info.value) == (
f"Result 'not-existing_run_0000' does not exist. {expected_known_results}"
)

with pytest.raises(ValueError) as exc_info:
project.load_latest_result("not-existing")

assert str(exc_info.value) == "Result 'not-existing' does not exist."
assert str(exc_info.value) == (
f"Result 'not-existing' does not exist. {expected_known_results}"
)

with pytest.raises(ValueError) as exc_info:
project.get_result_path("not-existing_run_0000")

assert str(exc_info.value) == "Result 'not-existing_run_0000' does not exist."
assert str(exc_info.value) == (
f"Result 'not-existing_run_0000' does not exist. {expected_known_results}"
)

with pytest.raises(ValueError) as exc_info:
project.get_latest_result_path("not-existing")

assert str(exc_info.value) == "Result 'not-existing' does not exist."
assert str(exc_info.value) == (
f"Result 'not-existing' does not exist. {expected_known_results}"
)


def test_markdown_repr(project_folder: Path, project_file: Path):
Expand All @@ -398,17 +427,20 @@ def test_markdown_repr(project_folder: Path, project_file: Path):
## Data
* dataset_1
* dataset_1-bad
* test_data
## Model
* test_model
* test_model-bad
## Parameters
* test_parameters
* test_parameters-bad
## Results
Expand Down

0 comments on commit ec4a90c

Please sign in to comment.