Skip to content

Commit

Permalink
Fix supplying (base_)compiledir as str
Browse files Browse the repository at this point in the history
  • Loading branch information
Armavica authored and ricardoV94 committed Jul 15, 2024
1 parent f35ce26 commit 67e1819
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 4 deletions.
8 changes: 4 additions & 4 deletions pytensor/configdefaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -1155,14 +1155,14 @@ def _default_compiledirname() -> str:
return safe


def _filter_base_compiledir(path: Path) -> Path:
def _filter_base_compiledir(path: str | Path) -> Path:
# Expand '~' in path
return path.expanduser()
return Path(path).expanduser()


def _filter_compiledir(path: Path) -> Path:
def _filter_compiledir(path: str | Path) -> Path:
# Expand '~' in path
path = path.expanduser()
path = Path(path).expanduser()
# Turn path into the 'real' path. This ensures that:
# 1. There is no relative path, which would fail e.g. when trying to
# import modules from the compile dir.
Expand Down
11 changes: 11 additions & 0 deletions tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import configparser as stdlib_configparser
import io
import pickle
from pathlib import Path
from tempfile import mkdtemp

import pytest

Expand All @@ -19,6 +21,15 @@ def _create_test_config():
)


def test_config_paths():
base_compiledir = mkdtemp()
assert configdefaults._filter_base_compiledir(str(base_compiledir)) == Path(
base_compiledir
)
compiledir = mkdtemp()
assert configdefaults._filter_compiledir(str(compiledir)) == Path(compiledir)


def test_invalid_default():
# Ensure an invalid default value found in the PyTensor code only causes
# a crash if it is not overridden by the user.
Expand Down

0 comments on commit 67e1819

Please sign in to comment.