From 5d00f19fa0a0aa549014e1592575e0abf52b874f Mon Sep 17 00:00:00 2001 From: Christoph Lassner Date: Sun, 25 Feb 2024 18:01:51 -0800 Subject: [PATCH] Fix #21 and add export functionality. --- pdm-plugin-torch/pdm_plugin_torch/config.py | 54 +++++-- pdm-plugin-torch/pdm_plugin_torch/main.py | 166 +++++++++++++++----- 2 files changed, 172 insertions(+), 48 deletions(-) diff --git a/pdm-plugin-torch/pdm_plugin_torch/config.py b/pdm-plugin-torch/pdm_plugin_torch/config.py index a5c00f3..095d61f 100644 --- a/pdm-plugin-torch/pdm_plugin_torch/config.py +++ b/pdm-plugin-torch/pdm_plugin_torch/config.py @@ -1,3 +1,4 @@ +"""Plugin configuration.""" from __future__ import annotations from dataclasses import dataclass, field @@ -5,41 +6,72 @@ @dataclass(frozen=True) class Configuration: + """ + Plugin configuration. + + Attributes: + dependencies: list of top level dependencies. + enable_cpu: CPU feature flag. + enable_cuda: CUDA feature flag. + enable_rocm: ROCm feature flag. + cuda_versions: list of versions for CUDA to support. + rocm_versions: list of ROCm versions to support. + lockfile: path to the lock file to use. + """ + + # Dependency list. dependencies: list[str] + # Feature flags. enable_cpu: bool = False - enable_cuda: bool = False - cuda_versions: list[str] = field(default_factory=list) - enable_rocm: bool = False + # Version identifiers for the different possible versioned dependencies. + cuda_versions: list[str] = field(default_factory=list) rocm_versions: list[str] = field(default_factory=list) - + # Lockfile configuration. lockfile: str = "torch.lock" + @staticmethod def from_toml(data: dict[str, str | list[str] | bool]) -> "Configuration": + """ + Create a configuration object from a pyproject.toml configuration file. + + Args: + data: parsed TOML of the pyproject file. + + Returns: + Configuration object. + """ fixed_dashes = {k.replace("-", "_"): v for (k, v) in data.items()} return Configuration(**fixed_dashes) @property - def variants(self): - resolves = {} + def variants(self) -> dict[str, tuple[str, str]]: + """ + Get resolution URL and build identifier for all configured variants for the plugin. + Returns: + A dictionary of torch build alternatives to a tuple of + (resolution URL, build identifier). + """ + resolves = {} + if self.enable_cpu: + # We can omit the build identifier for the CPU only versions + # since the resolution at the CPU URL works correctly for all + # versions only without a tag (see the MacOS builds at + # https://download.pytorch.org/whl/cpu). + resolves["cpu"] = ("https://download.pytorch.org/whl/cpu", "") if self.enable_cuda: for cuda_version in self.cuda_versions: resolves[cuda_version] = ( f"https://download.pytorch.org/whl/{cuda_version}/", f"+{cuda_version}", ) - if self.enable_rocm: for rocm_version in self.rocm_versions: resolves[f"rocm{rocm_version}"] = ( "https://download.pytorch.org/whl/", f"+rocm{rocm_version}", ) - - if self.enable_cpu: - resolves["cpu"] = ("https://download.pytorch.org/whl/cpu", "+cpu") - return resolves diff --git a/pdm-plugin-torch/pdm_plugin_torch/main.py b/pdm-plugin-torch/pdm_plugin_torch/main.py index a88f677..1a652fc 100644 --- a/pdm-plugin-torch/pdm_plugin_torch/main.py +++ b/pdm-plugin-torch/pdm_plugin_torch/main.py @@ -1,42 +1,48 @@ +"""Plugin main file.""" from __future__ import annotations import sys - +from pathlib import Path from typing import Iterable import tomlkit - from pdm import __version__, termui from pdm._types import RepositoryConfig from pdm.cli.commands.base import BaseCommand +from pdm.cli.commands.export import Command as PDMExportCommand from pdm.cli.utils import fetch_hashes, format_lockfile, format_resolution_impossible from pdm.core import Core +from pdm.formats import FORMATS from pdm.models.candidates import Candidate from pdm.models.repositories import BaseRepository, LockedRepository -from pdm.models.requirements import Requirement, parse_requirement +from pdm.models.requirements import Requirement, parse_requirement, strip_extras from pdm.models.specifiers import PySpecSet, get_specifier from pdm.project import Project from pdm.resolver import resolve -from pdm.resolver.providers import BaseProvider +from pdm.resolver.providers import ( + BaseProvider, + EagerUpdateProvider, + ReusePinProvider, +) from pdm.termui import Verbosity -from pdm.utils import atomic_open_for_write, expand_env_vars_in_auth +from pdm.utils import atomic_open_for_write, expand_env_vars_in_auth, normalize_name from resolvelib.reporters import BaseReporter from resolvelib.resolvers import ResolutionImpossible, ResolutionTooDeep, Resolver from pdm_plugin_torch.config import Configuration - +is_pdm212 = PySpecSet(">=2.12").contains(__version__.__version__) is_pdm210 = PySpecSet(">=2.10").contains(__version__.__version__) is_pdm29 = PySpecSet(">=2.9").contains(__version__.__version__) is_pdm28 = PySpecSet(">=2.8").contains(__version__.__version__) -def sources(project: Project, sources: list) -> list[RepositoryConfig]: +def sources(project: Project, source_list: list) -> list[RepositoryConfig]: result: dict[str, RepositoryConfig] = {} for source in project.pyproject.settings.get("source", []): result[source["name"]] = RepositoryConfig(**source, config_prefix="pypi") - for source in sources: + for source in source_list: result[source["name"]] = RepositoryConfig(**source, config_prefix="torch") def merge_sources(other_sources: Iterable[tuple[str, RepositoryConfig]]) -> None: @@ -71,46 +77,46 @@ def get_provider( tracked_names: Iterable[str] | None = None, allow_prereleases: bool = False, ) -> BaseProvider: - """Build a provider class for resolver. + """ + Build a provider class for resolver. + :param strategy: the resolve strategy :param tracked_names: the names of packages that needs to update :param for_install: if the provider is for install :returns: The provider object """ - from pdm.models.requirements import strip_extras - from pdm.resolver.providers import ( - BaseProvider, - EagerUpdateProvider, - ReusePinProvider, - ) - from pdm.utils import normalize_name repository = get_repository( project, raw_sources, for_install=for_install, lockfile=lockfile ) - overrides = { - normalize_name(k): v for k, v in project.pyproject.resolution_overrides.items() - } + if not is_pdm212: + overrides = { + normalize_name(k): v + for k, v in project.pyproject.resolution_overrides.items() + } locked_repository: LockedRepository | None = None if strategy != "all" or for_install: try: locked_repository = LockedRepository(lockfile, sources, project.environment) - except Exception: + except Exception as ex: # pylint: disable=W0718 if for_install: - raise + raise ex project.core.ui.echo( "Unable to reuse the lock file as it is not compatible with PDM", style="warning", err=True, ) + if is_pdm212: + additional_base_provider_args = [] + else: + additional_base_provider_args = [allow_prereleases, overrides] if locked_repository is None: - return BaseProvider(repository, allow_prereleases, overrides) - + return BaseProvider(repository, *additional_base_provider_args) if for_install: - return BaseProvider(locked_repository, allow_prereleases, overrides) + return BaseProvider(locked_repository, *additional_base_provider_args) provider_class = ReusePinProvider if strategy == "reuse" else EagerUpdateProvider tracked_names = [strip_extras(name)[0] for name in tracked_names or ()] @@ -119,8 +125,7 @@ def get_provider( locked_repository.all_candidates, tracked_names, repository, - allow_prereleases, - overrides, + *additional_base_provider_args, ) @@ -128,10 +133,10 @@ def get_repository( project: Project, raw_sources: list, cls: type[BaseRepository] | None = None, - for_install: bool = False, - lockfile: dict = None, + for_install: bool = False, # pylint: disable=W0613 + lockfile: dict = None, # pylint: disable=W0613 ) -> BaseRepository: - """Get the repository object""" + """Get the repository object.""" if cls is None: cls = project.core.repository_class @@ -186,7 +191,9 @@ def do_lock( raise ResolutionImpossible("Unable to find a resolution") from None else: if is_pdm210: - from pdm.project.lockfile import FLAG_STATIC_URLS + from pdm.project.lockfile import ( # pylint: disable=C0415 + FLAG_STATIC_URLS, + ) data = format_lockfile( project, @@ -197,13 +204,13 @@ def do_lock( ) elif is_pdm29: - data = format_lockfile(project, mapping, dependencies, static_urls=True) + data = format_lockfile(project, mapping, dependencies, static_urls=True) # pylint: disable=E1123, E1120 elif is_pdm28: - data = format_lockfile(project, mapping, dependencies, static_urls=True) + data = format_lockfile(project, mapping, dependencies, static_urls=True) # pylint: disable=E1123, E1120 else: - data = format_lockfile(project, mapping, dependencies) + data = format_lockfile(project, mapping, dependencies) # pylint: disable=E1123, E1120 ui.echo(f"{termui.Emoji.LOCK} Lock successful") return data @@ -212,6 +219,14 @@ def do_lock( def write_lockfile( project: Project, lock_name: str, toml_data: dict, show_message: bool = True ) -> None: + """Write the lockfile for this project. + + Args: + project: the project to write the lockfile for. + lock_name: name of the lock file relative to the repository root. + toml_data: the data to write to the lock file. + show_message: whether to show a log message to the console. + """ toml_data["metadata"] = project.get_lock_metadata() lockfile_file = project.root / lock_name @@ -226,6 +241,7 @@ def resolve_candidates_from_lockfile( requirements: Iterable[Requirement], raw_sources, lockfile: dict, + for_install: bool = True, # pylint: disable=W0613 ) -> dict[str, Candidate]: ui = project.core.ui resolve_max_rounds = int(project.config["strategy.resolve_max_rounds"]) @@ -238,7 +254,7 @@ def resolve_candidates_from_lockfile( with ui.open_spinner("Resolving packages from lockfile...") as spinner: reporter = BaseReporter() provider = get_provider( - project, raw_sources, for_install=True, lockfile=lockfile + project, raw_sources, for_install=for_install, lockfile=lockfile ) resolver: Resolver = project.core.resolver_class(provider, reporter) mapping, *_ = resolve( @@ -372,7 +388,7 @@ def handle(self, project: Project, options: dict): (source, local_version) = resolves[options.api] if is_pdm210: - from pdm.project.lockfile import FLAG_STATIC_URLS + from pdm.project.lockfile import FLAG_STATIC_URLS # pylint: disable=C0415 class OverrideLockfile: def __init__(self, lockfile): @@ -390,7 +406,7 @@ def __getattr__(self, name): original_lockfile = project.lockfile - project._lockfile = OverrideLockfile(original_lockfile) + project._lockfile = OverrideLockfile(original_lockfile) # pylint: disable=W0212 reqs = [ parse_requirement(f"{req}{local_version}", False) @@ -412,7 +428,7 @@ def __getattr__(self, name): ) if is_pdm210: - project._lockfile = original_lockfile + project._lockfile = original_lockfile # pylint: disable=W0212 class LockCommand(BaseCommand): @@ -469,11 +485,87 @@ def handle(self, project: Project, options: dict): write_lockfile(project, plugin_config.lockfile, results) +class ExportCommand(BaseCommand): + name = "export" + description = "Export Torch and its dependencies for a specific version" + + def add_arguments(self, parser): + parser.add_argument("api", help="the api to use, e.g. cuda version or rocm") + PDMExportCommand.add_arguments(None, parser) + + def handle(self, project: Project, options: dict): + plugin_config = Configuration.from_toml(get_settings(project)) + resolves = plugin_config.variants + if options.api not in resolves: + raise ValueError( + f"unknown API {options.api}, expected one of {[v for v in resolves]}" + ) + + lockfile = read_lockfile(project, plugin_config.lockfile) + spec_for_version = lockfile[options.api] + + (source, local_version) = resolves[options.api] + + if is_pdm210: + from pdm.project.lockfile import FLAG_STATIC_URLS + + class OverrideLockfile: + def __init__(self, lockfile): + self._lockfile = lockfile + + @property + def strategy(self): + strategies = self._lockfile.strategy + strategies.add(FLAG_STATIC_URLS) + + return strategies + + def __getattr__(self, name): + return getattr(self._lockfile, name) + + original_lockfile = project.lockfile + project._lockfile = OverrideLockfile(original_lockfile) # pylint: disable=W0212 + + reqs = [ + parse_requirement(f"{req}{local_version}", False) + for req in plugin_config.dependencies + ] + + candidates = resolve_candidates_from_lockfile( + project, + reqs, + [ + { + "name": "torch", + "url": source, + "type": "index", + "verify_ssl": True, + } + ], + spec_for_version, + for_install=False, + ) + packages = ( + candidate for candidate in candidates.values() if not candidate.req.extras + ) + + if is_pdm210: + project._lockfile = original_lockfile # pylint: disable=W0212 + + content = FORMATS[options.format].export(project, packages, options) + if options.output: + Path(options.output).write_text(content, encoding="utf-8") + else: + # Use a regular print to avoid any formatting / wrapping. + print(content) + + class TorchCommand(BaseCommand): """Generate a lockfile for torch specifically.""" name = "torch" description = "Manage torch dependencies" + parser = None def add_arguments(self, parser): subparsers = parser.add_subparsers(