diff --git a/thx/config.py b/thx/config.py index 7f4cacd..950669b 100644 --- a/thx/config.py +++ b/thx/config.py @@ -7,7 +7,7 @@ import tomli from trailrunner.core import project_root -from .types import Config, ConfigError, Job, Version +from .types import Builder, Config, ConfigError, Job, Version def ensure_dict(value: Any, key: str) -> Dict[str, Any]: @@ -134,6 +134,15 @@ def load_config(path: Optional[Path] = None) -> Config: content = pyproject.read_text() data = tomli.loads(content).get("tool", {}).get("thx", {}) + try: + builder_str = data.pop("builder", Builder.AUTO.value) + builder = Builder(builder_str) + except ValueError: + raise ConfigError( + f"Option tool.thx.builder: invalid value {builder_str!r}; " + f"expected one of {', '.join(b.value for b in Builder)}" + ) + default: List[str] = ensure_listish(data.pop("default", None), "tool.thx.default") jobs: List[Job] = parse_jobs(data.pop("jobs", {})) versions: List[Version] = sorted( @@ -170,6 +179,7 @@ def load_config(path: Optional[Path] = None) -> Config: requirements=requirements, extras=extras, watch_paths=watch_paths, + builder=builder, ) ) diff --git a/thx/context.py b/thx/context.py index 650b417..ad22179 100644 --- a/thx/context.py +++ b/thx/context.py @@ -2,10 +2,12 @@ # Licensed under the MIT License import logging +import os import platform import re import shutil import subprocess +import sys import time from itertools import chain from pathlib import Path @@ -15,8 +17,10 @@ from .runner import check_command from .types import ( + Builder, CommandError, Config, + ConfigError, Context, Event, Options, @@ -35,10 +39,17 @@ def venv_path(config: Config, version: Version) -> Path: + """ + Return the path for the desired virtual environment for the given version. + """ return config.root / ".thx" / "venv" / str(version) def runtime_version(binary: Path) -> Optional[Version]: + """Load the version printed by the given Python interpreter. + + Cache the result to avoid repeated calls. + """ if binary not in PYTHON_VERSIONS: try: proc = subprocess.run( @@ -78,6 +89,10 @@ def runtime_version(binary: Path) -> Optional[Version]: def find_runtime( version: Version, venv: Optional[Path] = None ) -> Tuple[Optional[Path], Optional[Version]]: + """ + Locate a Python interpreter matching the desired `version`. If `venv` is provided + and is a directory, look for its Python. Otherwise, try typical binary names. + """ if venv and venv.is_dir(): bin_dir = venv_bin_path(venv) binary_path_str = shutil.which("python", path=bin_dir.as_posix()) @@ -108,12 +123,90 @@ def find_runtime( return None, None +def identify_venv(venv_path: Path) -> Tuple[Path, Version]: + """Read the pyvenv.cfg from a venv to determine the Python version. + + Return a path to the Python interpreter and the version of that interpreter. + """ + cfg = venv_path / "pyvenv.cfg" + + try: + f = cfg.open() + except FileNotFoundError: + raise ConfigError(f"venv {venv_path} is missing pyvenv.cfg.") from None + + # Canonical parsing of pyvenv.cfg is here: + # https://github.com/python/cpython/blob/e65a1eb93ae35f9fbab1508606e3fbc89123629f/Modules/getpath.py#L372 + # The file is a simple key=value format and any lines that are malformed + # are ignored. + VERSION_KEYS = ( + "version_info", # uv + "version", # venv + ) + kvs = {} + version = None + with f: + for line in f: + key, eq, value = line.partition("=") + if eq and key.strip().lower() in VERSION_KEYS: + version = Version(value.strip()) + break + elif eq: + kvs[key.strip()] = value.strip() + + if version is None: + raise ConfigError( + f"pyvenv.cfg in venv {venv_path} does not contain version: {kvs}" + ) + + bin_dir = venv_bin_path(venv_path) + candidates = [ + f"python{version.major}.{version.minor}", + f"python{version.major}", + "python", + ] + for candidate in candidates: + python_path = bin_dir / candidate + if python_path.exists(): + break + else: + raise ConfigError(f"venv {venv_path} does not contain a Python interpreter") + return python_path, version + + @timed("resolve contexts") def resolve_contexts(config: Config, options: Options) -> List[Context]: + """Build a list of contexts in which to run. + + We evaluate the list of Python versions from config, as well as + command-line options refining the list. + """ + builder = determine_builder(config) + if options.live or not config.versions: version = Version(platform.python_version().rstrip("+")) # defer resolving python path to after venv creation - return [Context(version, Path(""), venv_path(config, version), live=True)] + return [ + Context( + version, + Path(sys.executable), + venv_path(config, version), + builder, + live=True, + ) + ] + + if builder == Builder.UV: + # If using uv we can let uv resolve the Python path for each version, + # which may involve installing a new Python version. + + versions = config.versions + if options.python is not None: + versions = version_match(config.versions, options.python) + return [ + Context(version, None, venv_path(config, version), builder) + for version in versions + ] contexts: List[Context] = [] missing_versions: List[Version] = [] @@ -124,7 +217,7 @@ def resolve_contexts(config: Config, options: Options) -> List[Context]: missing_versions.append(version) else: venv = venv_path(config, runtime_version) - contexts.append(Context(runtime_version, runtime_path, venv)) + contexts.append(Context(runtime_version, runtime_path, venv, builder)) if missing_versions: LOG.warning("missing Python versions: %r", [str(v) for v in missing_versions]) @@ -144,7 +237,11 @@ def resolve_contexts(config: Config, options: Options) -> List[Context]: def project_requirements(config: Config) -> Sequence[Path]: - """Get a list of Path objects for configured or discovered requirements files""" + """Return a list of requirements file paths for the project. + + If config.requirements is given, use those paths. Otherwise, search for + requirements*.txt files in the project root. + """ paths: List[Path] = [] if config.requirements: paths += [(config.root / req) for req in config.requirements] @@ -154,7 +251,11 @@ def project_requirements(config: Config) -> Sequence[Path]: def needs_update(context: Context, config: Config) -> bool: - """Compare timestamps of marker file and requirements files""" + """Return True if the environment needs to be rebuilt. + + We currently do this by comparing the modification time of all requirements + files to a stored timestamp file inside the venv. + """ try: timestamp = context.venv / TIMESTAMP if timestamp.exists(): @@ -178,75 +279,179 @@ def needs_update(context: Context, config: Config) -> bool: context.venv, exc_info=True, ) - return True @timed("prepare virtualenv") async def prepare_virtualenv(context: Context, config: Config) -> AsyncIterator[Event]: - """Setup virtualenv and install packages""" + """ + Prepare the virtual environment, either using pip or uv logic, + depending on config.builder (or auto). + """ + if needs_update(context, config): + LOG.info("preparing virtualenv %s", context.venv) + yield VenvCreate(context, message="creating virtualenv") + + builder = context.builder + if builder == Builder.UV: + task = prepare_virtualenv_uv(context, config) + elif builder == Builder.PIP: + task = prepare_virtualenv_pip(context, config) + else: + raise ConfigError(f"Unknown builder: {builder}") + async for event in task: + yield event + else: + LOG.debug("reusing existing virtualenv %s", context.venv) + yield VenvReady(context) - try: - if needs_update(context, config): - LOG.info("preparing virtualenv %s", context.venv) - yield VenvCreate(context, message="creating virtualenv") - - # create virtualenv - prompt = f"thx-{context.python_version}" - if context.live: - import venv - - venv.create(context.venv, prompt=prompt, with_pip=True) - - else: - await check_command( - [ - context.python_path, - "-m", - "venv", - "--prompt", - prompt, - context.venv, - ] - ) - new_python_path, new_python_version = find_runtime( - context.python_version, context.venv +def determine_builder(config: Config) -> Builder: + """Resolve which builder to use. + + If a builder is explicitly configured, attempt to use it (and fail if it + is unavailable.) + + If builder is auto, pick uv if available, else pip. + """ + uv = shutil.which("uv") + if config.builder == Builder.AUTO: + if uv is not None: + return Builder.UV + return Builder.PIP + if config.builder == Builder.UV: + if uv is None: + raise ConfigError("uv not found on PATH, cannot build with uv") + return config.builder + + +async def prepare_virtualenv_pip( + context: Context, config: Config +) -> AsyncIterator[Event]: + """Create and populate a venv using venv and pip.""" + try: + # Create the venv + if context.live: + import venv + + venv.create( + context.venv, + prompt=f"thx-{context.python_version}", + with_pip=True, + symlinks=(os.name != "nt"), ) - context.python_path = new_python_path or context.python_path - context.python_version = new_python_version or context.python_version + else: + assert ( + context.python_path is not None + ), "python_path must be resolved for non-live venv with pip" + await check_command( + [ + context.python_path, + "-m", + "venv", + "--prompt", + f"thx-{context.python_version}", + context.venv, + ] + ) + + # Update runtime in context + context.python_path, context.python_version = identify_venv(context.venv) + + # Upgrade pip, setuptools + yield VenvCreate(context, message="upgrading pip") + await check_command( + [ + context.python_path, + "-m", + "pip", + "install", + "-U", + "pip", + "setuptools", + ] + ) + pip = which("pip", context) + + # Install requirements + requirements = project_requirements(config) + if requirements: + yield VenvCreate(context, message="installing requirements") + LOG.debug("installing deps from %s", requirements) + cmd: List[StrPath] = [pip, "install", "-U"] + for requirement in requirements: + cmd.extend(["-r", requirement]) + await check_command(cmd) + + # Install local project + yield VenvCreate(context, message="installing project") + if config.extras: + proj = f"{config.root}[{','.join(config.extras)}]" + else: + proj = str(config.root) + await check_command([pip, "install", "-U", proj]) + + # Record a timestamp + (context.venv / TIMESTAMP).write_text(f"{time.time_ns()}\n") - # upgrade pip - yield VenvCreate(context, message="upgrading pip") + yield VenvReady(context) + + except CommandError as error: + yield VenvError(context, error) + + +async def prepare_virtualenv_uv( + context: Context, config: Config +) -> AsyncIterator[Event]: + """Create and populate a venv using uv.""" + try: + # Create the venv with uv + uv = shutil.which("uv") + if not uv: + raise ConfigError("uv not found on PATH, cannot build with uv") + + await check_command( + [ + uv, + "venv", + f"--prompt=thx-{context.python_version}", + "-p", + ( + str(context.python_path) + if context.python_path + else str(context.python_version) + ), + str(context.venv), + ] + ) + + context.python_path, context.python_version = identify_venv(context.venv) + + # Install requirements + requirements = project_requirements(config) + if requirements: + yield VenvCreate(context, message="installing requirements via uv") + LOG.debug("installing deps from %s with uv", requirements) + + # Equivalent to `pip install -U -r ` + reqs = [] + for requirement in requirements: + reqs.extend(["-r", str(requirement)]) await check_command( - [context.python_path, "-m", "pip", "install", "-U", "pip", "setuptools"] + [uv, "pip", "install", *reqs], + context=context, ) - pip = which("pip", context) - - # install requirements.txt - requirements = project_requirements(config) - if requirements: - yield VenvCreate(context, message="installing requirements") - LOG.debug("installing deps from %s", requirements) - cmd: List[StrPath] = [pip, "install", "-U"] - for requirement in requirements: - cmd.extend(["-r", requirement]) - await check_command(cmd) - - # install local project - yield VenvCreate(context, message="installing project") - if config.extras: - proj = f"{config.root}[{','.join(config.extras)}]" - else: - proj = str(config.root) - await check_command([pip, "install", "-U", proj]) - - # timestamp marker - content = f"{time.time_ns()}\n" - (context.venv / TIMESTAMP).write_text(content) + # Install local project + yield VenvCreate(context, message="installing project via uv") + if config.extras: + proj = f"{config.root}[{','.join(config.extras)}]" else: - LOG.debug("reusing existing virtualenv %s", context.venv) + proj = str(config.root) + await check_command([uv, "pip", "install", proj], context=context) + + # Record a timestamp + (context.venv / TIMESTAMP).write_text(f"{time.time_ns()}\n") yield VenvReady(context) @@ -258,6 +463,9 @@ async def prepare_virtualenv(context: Context, config: Config) -> AsyncIterator[ async def prepare_contexts( contexts: Sequence[Context], config: Config ) -> AsyncIterator[Event]: + """ + Prepare each context in parallel (as an async generator of events). + """ gens = [prepare_virtualenv(context, config) for context in contexts] async for event in as_generated(gens): yield event diff --git a/thx/runner.py b/thx/runner.py index 8b22234..31086aa 100644 --- a/thx/runner.py +++ b/thx/runner.py @@ -39,6 +39,7 @@ async def run_command( if context: new_env = os.environ.copy() new_env["PATH"] = f"{venv_bin_path(context.venv)}{os.pathsep}{new_env['PATH']}" + new_env["VIRTUAL_ENV"] = str(context.venv) proc = await asyncio.create_subprocess_exec( *cmd, stdout=PIPE, stderr=PIPE, env=new_env ) @@ -51,8 +52,10 @@ async def run_command( ) -async def check_command(command: Sequence[StrPath]) -> CommandResult: - result = await run_command(command) +async def check_command( + command: Sequence[StrPath], context: Optional[Context] = None +) -> CommandResult: + result = await run_command(command, context) if result.error: raise CommandError(command, result) diff --git a/thx/tests/context.py b/thx/tests/context.py index d91a330..5c89583 100644 --- a/thx/tests/context.py +++ b/thx/tests/context.py @@ -4,6 +4,7 @@ import asyncio import platform import subprocess +import sys from pathlib import Path from tempfile import TemporaryDirectory from typing import AsyncIterator, List, Optional, Sequence, Tuple @@ -14,6 +15,7 @@ from .. import context from ..types import ( + Builder, CommandResult, Config, Context, @@ -249,14 +251,15 @@ def test_find_runtime_venv(self, runtime_mock: Mock, which_mock: Mock) -> None: def test_resolve_contexts_no_config(self, runtime_mock: Mock) -> None: with TemporaryDirectory() as td: tdp = Path(td).resolve() - config = Config(root=tdp) + config = Config(root=tdp, builder=Builder.PIP) active_version = Version(platform.python_version()) expected = [ Context( active_version, - Path(""), + Path(sys.executable), context.venv_path(config, active_version), live=True, + builder=Builder.PIP, ) ] result = context.resolve_contexts(config, Options()) @@ -270,7 +273,7 @@ def test_resolve_contexts_multiple_versions( ) -> None: with TemporaryDirectory() as td: tdp = Path(td).resolve() - config = Config(root=tdp, versions=TEST_VERSIONS) + config = Config(root=tdp, versions=TEST_VERSIONS, builder=Builder.PIP) expected_venvs = { version: context.venv_path(config, version) for version in TEST_VERSIONS @@ -366,9 +369,10 @@ async def test_needs_update(self) -> None: @patch("thx.context.check_command") @patch("thx.context.which") + @patch("thx.context.identify_venv") @async_test async def test_prepare_virtualenv_extras( - self, which_mock: Mock, run_mock: Mock + self, identity_mock: Mock, which_mock: Mock, run_mock: Mock ) -> None: self.maxDiff = None @@ -383,6 +387,8 @@ async def fake_check_command(cmd: Sequence[StrPath]) -> CommandResult: venv = tdp / ".thx" / "venv" / "3.9" venv.mkdir(parents=True) + identity_mock.return_value = (Version("3.9.21"), venv / "bin/python3.9") + config = Config(root=tdp, extras=["more"]) ctx = Context(Version("3.9"), venv / "bin" / "python", venv) pip = which_mock("pip", ctx) @@ -421,7 +427,9 @@ async def fake_check_command(cmd: Sequence[StrPath]) -> CommandResult: async def test_prepare_virtualenv_live( self, which_mock: Mock, run_mock: Mock ) -> None: - async def fake_check_command(cmd: Sequence[StrPath]) -> CommandResult: + async def fake_check_command( + cmd: Sequence[StrPath], context: Optional[Context] = None + ) -> CommandResult: return CommandResult(0, "", "") run_mock.side_effect = fake_check_command @@ -432,7 +440,7 @@ async def fake_check_command(cmd: Sequence[StrPath]) -> CommandResult: reqs = tdp / "requirements.txt" reqs.write_text("\n") - config = Config(root=tdp) + config = Config(root=tdp, builder=Builder.PIP) ctx = context.resolve_contexts(config, Options(live=True))[0] self.assertTrue(ctx.live) diff --git a/thx/types.py b/thx/types.py index 16083bd..b8de7df 100644 --- a/thx/types.py +++ b/thx/types.py @@ -2,6 +2,7 @@ # Licensed under the MIT License from dataclasses import dataclass, field +from enum import Enum from pathlib import Path from shlex import quote from typing import ( @@ -66,6 +67,12 @@ def __post_init__(self) -> None: self.requires = tuple(r.casefold() for r in self.requires) +class Builder(Enum): + PIP = "pip" + UV = "uv" + AUTO = "auto" + + @dataclass class Config: root: Path = field(default_factory=Path.cwd) @@ -76,6 +83,7 @@ class Config: requirements: Sequence[str] = field(default_factory=list) extras: Sequence[str] = field(default_factory=list) watch_paths: Set[Path] = field(default_factory=set) + builder: Builder = Builder.AUTO def __post_init__(self) -> None: self.default = tuple(d.casefold() for d in self.default) @@ -84,8 +92,9 @@ def __post_init__(self) -> None: @dataclass(unsafe_hash=True) class Context: python_version: Version - python_path: Path + python_path: Optional[Path] venv: Path + builder: Builder = Builder.PIP live: bool = False diff --git a/thx/utils.py b/thx/utils.py index 5b16b30..e9e6fb7 100644 --- a/thx/utils.py +++ b/thx/utils.py @@ -10,7 +10,7 @@ from itertools import zip_longest from pathlib import Path from time import monotonic_ns -from typing import Any, Callable, List, Optional, TypeVar +from typing import Any, Callable, Iterable, List, Optional, TypeVar from typing_extensions import ParamSpec @@ -130,7 +130,7 @@ def which(name: str, context: Context) -> str: return binary -def version_match(versions: List[Version], target: Version) -> List[Version]: +def version_match(versions: Iterable[Version], target: Version) -> List[Version]: matches: List[Version] = [] for version in versions: if all(