Skip to content

Commit

Permalink
templates: Add optional fs arg.
Browse files Browse the repository at this point in the history
Replace `rglob` with exatch matching. Fixes potential bug from using `rglob({template_name}*)`

Per https://github.com/iterative/studio/pull/4504
  • Loading branch information
daavoo committed Nov 15, 2022
1 parent 436cee3 commit 3305982
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 7 deletions.
22 changes: 15 additions & 7 deletions src/dvc_render/vega_templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -725,14 +725,19 @@ class LinearTemplate(Template):


def _find_template(
template_name: str, template_dir: Optional[str] = None
template_name: str, template_dir: Optional[str] = None, fs=None
) -> Optional["StrPath"]:
_exists = Path.exists if fs is None else fs.exists

if template_dir:
for template_path in Path(template_dir).rglob(f"{template_name}*"):
template_path = Path(template_dir) / template_name
if _exists(template_path):
return template_path
elif _exists(template_path.with_suffix(Template.EXTENSION)):
return template_path.with_suffix(Template.EXTENSION)

template_path = Path(template_name)
if template_path.exists():
if _exists(template_path):
return template_path.resolve()

return None
Expand All @@ -741,13 +746,15 @@ def _find_template(
def get_template(
template: Union[Optional[str], Template] = None,
template_dir: Optional[str] = None,
fs=None,
) -> Template:
"""Return template instance based on given template arg.
If template is already an instance, return it.
If template is None, return default `linear` template.
If template is a path, will try to find it as absolute
path or inside template_dir.
If template is a path, will try to find it:
- Inside `template_dir`
- As a relative path to cwd.
If template matches one of the DEFAULT_NAMEs in TEMPLATES,
return an instance of the one matching.
"""
Expand All @@ -757,10 +764,11 @@ def get_template(
if template is None:
template = "linear"

template_path = _find_template(template, template_dir)
template_path = _find_template(template, template_dir, fs)

_open = open if fs is None else fs.open
if template_path:
with open(template_path, "r", encoding="utf-8") as f:
with _open(template_path, encoding="utf-8") as f:
content = f.read()
return Template(content, name=template)

Expand Down
16 changes: 16 additions & 0 deletions tests/test_templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,27 @@ def test_get_template_from_dir(tmp_dir, template_path, target_name):
)


def test_get_template_exact_match(tmp_dir):
tmp_dir.gen(os.path.join("foodir", "bar_template.json"), "bar")
with pytest.raises(TemplateNotFoundError):
# This was unexpectedly working when using rglob({template_name}*)
# and could cause bugs.
get_template("bar", "foodir")


def test_get_template_from_file(tmp_dir):
tmp_dir.gen("foo/bar.json", "template_content")
assert get_template("foo/bar.json").content == "template_content"


def test_get_template_fs(tmp_dir, mocker):
tmp_dir.gen("foo/bar.json", "template_content")
fs = mocker.MagicMock()
get_template("foo/bar.json", fs=fs)
fs.open.assert_called()
fs.exists.assert_called()


def test_get_default_template():
assert get_template(None).content == LinearTemplate().content

Expand Down

0 comments on commit 3305982

Please sign in to comment.