diff --git a/conda_lock/conda_lock.py b/conda_lock/conda_lock.py
index 66ef8d34a..88d48c0cc 100644
--- a/conda_lock/conda_lock.py
+++ b/conda_lock/conda_lock.py
@@ -62,26 +62,20 @@
     PIP_SUPPORT = True
 except ImportError:
     PIP_SUPPORT = False
-from conda_lock.lookup import set_lookup_location
-from conda_lock.src_parser import (
-    Dependency,
+from conda_lock.lockfile import (
     GitMeta,
     InputMeta,
     LockedDependency,
     Lockfile,
     LockMeta,
-    LockSpecification,
-    MetadataOption,
     TimeMeta,
     UpdateSpecification,
-    aggregate_lock_specs,
+    parse_conda_lock_file,
+    write_conda_lock_file,
 )
-from conda_lock.src_parser.environment_yaml import parse_environment_file
-from conda_lock.src_parser.lockfile import parse_conda_lock_file, write_conda_lock_file
-from conda_lock.src_parser.meta_yaml import parse_meta_yaml_file
-from conda_lock.src_parser.pyproject_toml import parse_pyproject_toml
+from conda_lock.lookup import set_lookup_location
+from conda_lock.src_parser import LockSpecification, MetadataOption, make_lock_spec
 from conda_lock.virtual_package import (
-    FakeRepoData,
     default_virtual_package_repodata,
     virtual_package_repo_from_specification,
 )
@@ -114,8 +108,6 @@
         sys.exit(1)
 
 
-DEFAULT_PLATFORMS = ["osx-64", "linux-64", "win-64"]
-
 KIND_EXPLICIT: Literal["explicit"] = "explicit"
 KIND_LOCK: Literal["lock"] = "lock"
 KIND_ENV: Literal["env"] = "env"
@@ -243,44 +235,6 @@ def fn_to_dist_name(fn: str) -> str:
     return fn
 
 
-def make_lock_spec(
-    *,
-    src_files: List[pathlib.Path],
-    virtual_package_repo: FakeRepoData,
-    channel_overrides: Optional[Sequence[str]] = None,
-    platform_overrides: Optional[Sequence[str]] = None,
-    required_categories: Optional[AbstractSet[str]] = None,
-) -> LockSpecification:
-    """Generate the lockfile specs from a set of input src_files. If required_categories is set filter out specs that do not match those"""
-    lock_specs = parse_source_files(
-        src_files=src_files, platform_overrides=platform_overrides
-    )
-
-    lock_spec = aggregate_lock_specs(lock_specs)
-    lock_spec.virtual_package_repo = virtual_package_repo
-    lock_spec.channels = (
-        [Channel.from_string(co) for co in channel_overrides]
-        if channel_overrides
-        else lock_spec.channels
-    )
-    lock_spec.platforms = (
-        list(platform_overrides) if platform_overrides else lock_spec.platforms
-    ) or list(DEFAULT_PLATFORMS)
-
-    if required_categories is not None:
-
-        def dep_has_category(d: Dependency, categories: AbstractSet[str]) -> bool:
-            return d.category in categories
-
-        lock_spec.dependencies = [
-            d
-            for d in lock_spec.dependencies
-            if dep_has_category(d, categories=required_categories)
-        ]
-
-    return lock_spec
-
-
 def make_lock_files(
     *,
     conda: PathLike,
@@ -353,11 +307,12 @@ def make_lock_files(
 
     with virtual_package_repo:
         lock_spec = make_lock_spec(
-            src_files=src_files,
+            src_file_paths=src_files,
             channel_overrides=channel_overrides,
-            platform_overrides=platform_overrides,
+            platform_overrides=set(platform_overrides) if platform_overrides else set(),
            virtual_package_repo=virtual_package_repo,
             required_categories=required_categories if filter_categories else None,
+            pip_support=PIP_SUPPORT,
         )
 
         lock_content: Optional[Lockfile] = None
@@ -704,12 +659,8 @@ def _solve_for_arch(
     """
     if update_spec is None:
         update_spec = UpdateSpecification()
-    # filter requested and locked dependencies to the current platform
-    dependencies = [
-        dep
-        for dep in spec.dependencies
-        if (not dep.selectors.platform) or platform in dep.selectors.platform
-    ]
+    dependencies = spec.dependencies[platform]
+
     locked = [dep for dep in update_spec.locked if dep.platform == platform]
     requested_deps_by_name = {
         manager: {dep.name: dep for dep in dependencies if dep.manager == manager}
@@ -867,42 +818,6 @@ def create_lockfile_from_spec(
     )
 
 
-def parse_source_files(
-    src_files: List[pathlib.Path],
-    platform_overrides: Optional[Sequence[str]],
-) -> List[LockSpecification]:
-    """
-    Parse a sequence of dependency specifications from source files
-
-    Parameters
-    ----------
-    src_files :
-        Files to parse for dependencies
-    platform_overrides :
-        Target platforms to render environment.yaml and meta.yaml files for
-    """
-    desired_envs: List[LockSpecification] = []
-    for src_file in src_files:
-        if src_file.name == "meta.yaml":
-            desired_envs.append(
-                parse_meta_yaml_file(
-                    src_file, list(platform_overrides or DEFAULT_PLATFORMS)
-                )
-            )
-        elif src_file.name == "pyproject.toml":
-            desired_envs.append(parse_pyproject_toml(src_file))
-        else:
-            desired_envs.append(
-                parse_environment_file(
-                    src_file,
-                    platform_overrides,
-                    default_platforms=DEFAULT_PLATFORMS,
-                    pip_support=PIP_SUPPORT,
-                )
-            )
-    return desired_envs
-
-
 def _add_auth_to_line(line: str, auth: Dict[str, str]) -> str:
     matching_auths = [a for a in auth if a in line]
     if not matching_auths:
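Annotation: the hunks above change the public entry point. `make_lock_spec` moves into `conda_lock.src_parser`, takes `src_file_paths` plus an optional set of platform overrides, and returns a spec whose `dependencies` are keyed by platform, so `_solve_for_arch` indexes the dict instead of re-filtering selectors. A minimal sketch of the new call shape, assuming this branch is installed (the input path is illustrative):

    import pathlib

    from conda_lock.src_parser import make_lock_spec
    from conda_lock.virtual_package import default_virtual_package_repodata

    vpr = default_virtual_package_repodata()
    with vpr:
        spec = make_lock_spec(
            src_file_paths=[pathlib.Path("environment.yml")],  # illustrative path
            virtual_package_repo=vpr,
            platform_overrides={"linux-64", "osx-64"},
            pip_support=True,
        )

    # dependencies is now a Dict[str, List[Dependency]] keyed by platform
    for platform, deps in spec.dependencies.items():
        print(platform, [d.name for d in deps])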
diff --git a/conda_lock/conda_solver.py b/conda_lock/conda_solver.py
index c8903ef99..adfe12426 100644
--- a/conda_lock/conda_solver.py
+++ b/conda_lock/conda_solver.py
@@ -31,14 +31,9 @@
     conda_pkgs_dir,
     is_micromamba,
 )
+from conda_lock.lockfile import HashModel, LockedDependency, _apply_categories
 from conda_lock.models.channel import Channel
-from conda_lock.src_parser import (
-    Dependency,
-    HashModel,
-    LockedDependency,
-    VersionedDependency,
-    _apply_categories,
-)
+from conda_lock.src_parser import Dependency, VersionedDependency
 
 
 logger = logging.getLogger(__name__)
diff --git a/conda_lock/src_parser/lockfile.py b/conda_lock/lockfile/__init__.py
similarity index 59%
rename from conda_lock/src_parser/lockfile.py
rename to conda_lock/lockfile/__init__.py
index aa37be00e..a6d017364 100644
--- a/conda_lock/src_parser/lockfile.py
+++ b/conda_lock/lockfile/__init__.py
@@ -1,14 +1,79 @@
 import json
 import pathlib
 
+from collections import defaultdict
 from textwrap import dedent
-from typing import Collection, Optional
+from typing import Collection, Dict, List, Optional, Sequence, Set
 
 import yaml
 
-from conda_lock.src_parser import MetadataOption
+from conda_lock.src_parser import Dependency, MetadataOption
 
-from . import Lockfile
+from .models import *
+from .models import LockedDependency, Lockfile
+
+
+def _apply_categories(
+    requested: Dict[str, Dependency],
+    planned: Dict[str, LockedDependency],
+    categories: Sequence[str] = ("main", "dev"),
+) -> None:
+    """map each package onto the root request with the highest-priority category"""
+    # walk dependency tree to assemble all transitive dependencies by request
+    dependents: Dict[str, Set[str]] = {}
+    by_category = defaultdict(list)
+
+    def seperator_munge_get(
+        d: Dict[str, LockedDependency], key: str
+    ) -> LockedDependency:
+        # since separators are not consistent across managers (or even within) we need to do some double attempts here
+        try:
+            return d[key]
+        except KeyError:
+            try:
+                return d[key.replace("-", "_")]
+            except KeyError:
+                return d[key.replace("_", "-")]
+
+    for name, request in requested.items():
+        todo: List[str] = list()
+        deps: Set[str] = set()
+        item = name
+        while True:
+            todo.extend(
+                dep
+                for dep in seperator_munge_get(planned, item).dependencies
+                # exclude virtual packages
+                if not (dep in deps or dep.startswith("__"))
+            )
+            if todo:
+                item = todo.pop(0)
+                deps.add(item)
+            else:
+                break
+
+        dependents[name] = deps
+
+        by_category[request.category].append(request.name)
+
+    # now, map each package to its root request
+    categories = [*categories, *(k for k in by_category if k not in categories)]
+    root_requests = {}
+    for category in categories:
+        for root in by_category.get(category, []):
+            for transitive_dep in dependents[root]:
+                if transitive_dep not in root_requests:
+                    root_requests[transitive_dep] = root
+    # include root requests themselves
+    for name in requested:
+        root_requests[name] = name
+
+    for dep, root in root_requests.items():
+        source = requested[root]
+        # try a conda target first
+        target = seperator_munge_get(planned, dep)
+        target.category = source.category
+        target.optional = source.optional
 
 
 def parse_conda_lock_file(
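Annotation: `_apply_categories` is the piece that propagates a root request's category onto its transitive dependencies. A toy walk-through of that behavior, assuming this branch is installed; the package names, versions, and URLs below are made up:

    from conda_lock.lockfile import HashModel, LockedDependency, _apply_categories
    from conda_lock.src_parser import VersionedDependency

    # One explicit "dev" request whose transitive dependency should inherit it.
    requested = {
        "pytest": VersionedDependency(
            name="pytest", version=">=7", manager="conda",
            optional=True, category="dev",
        ),
    }
    planned = {
        "pytest": LockedDependency(
            name="pytest", version="7.1.2", manager="conda", platform="linux-64",
            dependencies={"pluggy": ">=0.12"},
            url="https://example.invalid/pytest.tar.bz2", hash=HashModel(md5="0" * 32),
        ),
        "pluggy": LockedDependency(
            name="pluggy", version="1.0.0", manager="conda", platform="linux-64",
            dependencies={},
            url="https://example.invalid/pluggy.tar.bz2", hash=HashModel(md5="0" * 32),
        ),
    }

    _apply_categories(requested=requested, planned=planned)
    assert planned["pluggy"].category == "dev"  # inherited from its root request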
+ """ + + git_user_name: Optional[str] = Field( + default=None, description="Git user.name field of global config" + ) + git_user_email: Optional[str] = Field( + default=None, description="Git user.email field of global config" + ) + git_sha: Optional[str] = Field( + default=None, + description=( + "sha256 hash of the most recent git commit that modified one of the input files for " + + "this lockfile" + ), + ) + + @classmethod + def create( + cls, + metadata_choices: AbstractSet[MetadataOption], + src_files: List[pathlib.Path], + ) -> "GitMeta | None": + try: + import git + except ImportError: + return None + + git_sha: "str | None" = None + git_user_name: "str | None" = None + git_user_email: "str | None" = None + + try: + repo = git.Repo(search_parent_directories=True) # type: ignore + if MetadataOption.GitSha in metadata_choices: + most_recent_datetime: Optional[datetime.datetime] = None + for src_file in src_files: + relative_src_file_path = relative_path( + pathlib.Path(repo.working_tree_dir), src_file # type: ignore + ) + commit = list( + repo.iter_commits(paths=relative_src_file_path, max_count=1) + )[0] + if repo.is_dirty(path=relative_src_file_path): + logger.warning( + "One of the inputs to conda-lock is dirty, using commit hash of head +" + ' "dirty"' + ) + git_sha = f"{repo.head.object.hexsha}-dirty" + break + else: + if ( + most_recent_datetime is None + or most_recent_datetime < commit.committed_datetime + ): + most_recent_datetime = commit.committed_datetime + git_sha = commit.hexsha + if MetadataOption.GitUserName in metadata_choices: + git_user_name = repo.config_reader().get_value("user", "name", None) # type: ignore + if MetadataOption.GitUserEmail in metadata_choices: + git_user_email = repo.config_reader().get_value("user", "email", None) # type: ignore + except git.exc.InvalidGitRepositoryError: # type: ignore + pass + + if any([git_sha, git_user_name, git_user_email]): + return cls( + git_sha=git_sha, + git_user_name=git_user_name, + git_user_email=git_user_email, + ) + else: + return None + + +class InputMeta(StrictModel): + """Stores information about an input provided to generate the lockfile.""" + + md5: Optional[str] = Field(..., description="md5 checksum for an input file") + sha256: Optional[str] = Field(..., description="md5 checksum for an input file") + + @classmethod + def create( + cls, metadata_choices: AbstractSet[MetadataOption], src_file: pathlib.Path + ) -> "InputMeta": + if MetadataOption.InputSha in metadata_choices: + sha256 = cls.get_input_sha256(src_file=src_file) + else: + sha256 = None + if MetadataOption.InputMd5 in metadata_choices: + md5 = cls.get_input_md5(src_file=src_file) + else: + md5 = None + return cls( + md5=md5, + sha256=sha256, + ) + + @classmethod + def get_input_md5(cls, src_file: pathlib.Path) -> str: + hasher = hashlib.md5() + return cls.hash_file(src_file=src_file, hasher=hasher) + + @classmethod + def get_input_sha256(cls, src_file: pathlib.Path) -> str: + hasher = hashlib.sha256() + return cls.hash_file(src_file=src_file, hasher=hasher) + + @staticmethod + def hash_file(src_file: pathlib.Path, hasher: "_Hash") -> str: + with src_file.open("r") as infile: + hasher.update(infile.read().encode("utf-8")) + return hasher.hexdigest() + + +class LockMeta(StrictModel): + content_hash: Dict[str, str] = Field( + ..., description="Hash of dependencies for each target platform" + ) + channels: List[Channel] = Field( + ..., description="Channels used to resolve dependencies" + ) + platforms: List[str] = Field(..., description="Target 
platforms") + sources: List[str] = Field( + ..., + description="paths to source files, relative to the parent directory of the lockfile", + ) + time_metadata: Optional[TimeMeta] = Field( + default=None, description="Metadata dealing with the time lockfile was created" + ) + git_metadata: Optional[GitMeta] = Field( + default=None, + description=( + "Metadata dealing with the git repo the lockfile was created in and the user that created it" + ), + ) + inputs_metadata: Optional[Dict[str, InputMeta]] = Field( + default=None, + description="Metadata dealing with the input files used to create the lockfile", + ) + custom_metadata: Optional[Dict[str, str]] = Field( + default=None, + description="Custom metadata provided by the user to be added to the lockfile", + ) + + def __or__(self, other: "LockMeta") -> "LockMeta": + """merge other into self""" + if other is None: + return self + elif not isinstance(other, LockMeta): + raise TypeError + + if self.inputs_metadata is None: + new_inputs_metadata = other.inputs_metadata + elif other.inputs_metadata is None: + new_inputs_metadata = self.inputs_metadata + else: + new_inputs_metadata = self.inputs_metadata + new_inputs_metadata.update(other.inputs_metadata) + + if self.custom_metadata is None: + new_custom_metadata = other.custom_metadata + elif other.custom_metadata is None: + new_custom_metadata = self.custom_metadata + else: + new_custom_metadata = self.custom_metadata + for key in other.custom_metadata: + if key in new_custom_metadata: + logger.warning( + f"Custom metadata key {key} provided twice, overwriting original value" + + f"({new_custom_metadata[key]}) with new value " + + f"({other.custom_metadata[key]})" + ) + new_custom_metadata.update(other.custom_metadata) + return LockMeta( + content_hash={**self.content_hash, **other.content_hash}, + channels=self.channels, + platforms=sorted(set(self.platforms).union(other.platforms)), + sources=ordered_union([self.sources, other.sources]), + time_metadata=other.time_metadata, + git_metadata=other.git_metadata, + inputs_metadata=new_inputs_metadata, + custom_metadata=new_custom_metadata, + ) + + @validator("channels", pre=True, always=True) + def ensure_channels(cls, v: List[Union[str, Channel]]) -> List[Channel]: + res = [] + for e in v: + if isinstance(e, str): + res.append(Channel.from_string(e)) + else: + res.append(e) + return typing.cast(List[Channel], res) + + +class LockedDependency(StrictModel): + name: str + version: str + manager: Literal["conda", "pip"] + platform: str + dependencies: Dict[str, str] = {} + url: str + hash: HashModel + optional: bool = False + category: str = "main" + source: Optional[DependencySource] = None + build: Optional[str] = None + + def key(self) -> LockKey: + return LockKey(self.manager, self.name, self.platform) + + @validator("hash") + def validate_hash(cls, v: HashModel, values: Dict[str, typing.Any]) -> HashModel: + if (values["manager"] == "conda") and (v.md5 is None): + raise ValueError("conda package hashes must use MD5") + return v + + +class Lockfile(StrictModel): + version: ClassVar[int] = 1 + + package: List[LockedDependency] + metadata: LockMeta + + def __or__(self, other: "Lockfile") -> "Lockfile": + return other.__ror__(self) + + def __ror__(self, other: "Optional[Lockfile]") -> "Lockfile": + """ + merge self into other + """ + if other is None: + return self + elif not isinstance(other, Lockfile): + raise TypeError + + assert self.metadata.channels == other.metadata.channels + + ours = {d.key(): d for d in self.package} + theirs = {d.key(): d 
diff --git a/conda_lock/models/__init__.py b/conda_lock/models/__init__.py
index e69de29bb..20251e462 100644
--- a/conda_lock/models/__init__.py
+++ b/conda_lock/models/__init__.py
@@ -0,0 +1,9 @@
+from pydantic import BaseModel
+
+
+class StrictModel(BaseModel):
+    class Config:
+        extra = "forbid"
+        json_encoders = {
+            frozenset: list,
+        }
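Annotation: the new shared base model simply centralizes the `extra = "forbid"` config that both the spec and lockfile models rely on. For instance (the `Probe` model here is hypothetical):

    from pydantic import ValidationError

    from conda_lock.models import StrictModel


    class Probe(StrictModel):
        name: str


    Probe(name="ok")  # fine
    try:
        Probe(name="ok", unexpected=1)  # unknown field
    except ValidationError as err:
        print(err)  # extra fields not permitted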
diff --git a/conda_lock/pypi_solver.py b/conda_lock/pypi_solver.py
index ca0cd27a3..2df485be1 100644
--- a/conda_lock/pypi_solver.py
+++ b/conda_lock/pypi_solver.py
@@ -10,7 +10,7 @@
 from clikit.io import ConsoleIO, NullIO
 from packaging.tags import compatible_tags, cpython_tags
 
-from conda_lock import src_parser
+from conda_lock import lockfile, src_parser
 from conda_lock._vendor.poetry.core.packages import Dependency as PoetryDependency
 from conda_lock._vendor.poetry.core.packages import Package as PoetryPackage
 from conda_lock._vendor.poetry.core.packages import (
@@ -159,7 +159,7 @@ def get_dependency(dep: src_parser.Dependency) -> PoetryDependency:
         raise ValueError(f"Unknown requirement {dep}")
 
 
-def get_package(locked: src_parser.LockedDependency) -> PoetryPackage:
+def get_package(locked: lockfile.LockedDependency) -> PoetryPackage:
     if locked.source is not None:
         return PoetryPackage(
             locked.name,
@@ -174,12 +174,12 @@ def get_package(locked: lockfile.LockedDependency) -> PoetryPackage:
 def solve_pypi(
     pip_specs: Dict[str, src_parser.Dependency],
     use_latest: List[str],
-    pip_locked: Dict[str, src_parser.LockedDependency],
-    conda_locked: Dict[str, src_parser.LockedDependency],
+    pip_locked: Dict[str, lockfile.LockedDependency],
+    conda_locked: Dict[str, lockfile.LockedDependency],
     python_version: str,
     platform: str,
     verbose: bool = False,
-) -> Dict[str, src_parser.LockedDependency]:
+) -> Dict[str, lockfile.LockedDependency]:
     """
     Solve pip dependencies for the given platform
 
@@ -226,7 +226,7 @@ def solve_pypi(
     locked = Repository()
 
     python_packages = dict()
-    locked_dep: src_parser.LockedDependency
+    locked_dep: lockfile.LockedDependency
     for locked_dep in conda_locked.values():
         if locked_dep.name.startswith("__"):
             continue
@@ -273,16 +273,16 @@ def solve_pypi(
     # Extract distributions from Poetry package plan, ignoring uninstalls
     # (usually: conda package with no pypi equivalent) and skipped ops
     # (already installed)
-    requirements: List[src_parser.LockedDependency] = []
+    requirements: List[lockfile.LockedDependency] = []
     for op in result:
         if not isinstance(op, Uninstall) and not op.skipped:
             # Take direct references verbatim
-            source: Optional[src_parser.DependencySource] = None
+            source: Optional[lockfile.DependencySource] = None
             if op.package.source_type == "url":
                 url, fragment = urldefrag(op.package.source_url)
                 hash_type, hash = fragment.split("=")
-                hash = src_parser.HashModel(**{hash_type: hash})
-                source = src_parser.DependencySource(
+                hash = lockfile.HashModel(**{hash_type: hash})
+                source = lockfile.DependencySource(
                     type="url", url=op.package.source_url
                 )
             # Choose the most specific distribution for the target
@@ -292,10 +292,10 @@ def solve_pypi(
             hashes: Dict[str, str] = {}
             if link.hash_name is not None and link.hash is not None:
                 hashes[link.hash_name] = link.hash
-            hash = src_parser.HashModel.parse_obj(hashes)
+            hash = lockfile.HashModel.parse_obj(hashes)
 
             requirements.append(
-                src_parser.LockedDependency(
+                lockfile.LockedDependency(
                     name=op.package.name,
                     version=str(op.package.version),
                     manager="pip",
@@ -324,6 +324,6 @@ def solve_pypi(
             continue
         planned[pypi_name] = locked_dep
 
-    src_parser._apply_categories(requested=pip_specs, planned=planned)
+    lockfile._apply_categories(requested=pip_specs, planned=planned)
 
     return {dep.name: dep for dep in requirements}
diff --git a/conda_lock/src_parser/__init__.py b/conda_lock/src_parser/__init__.py
index 8c9965e65..3ff9d9b98 100644
--- a/conda_lock/src_parser/__init__.py
+++ b/conda_lock/src_parser/__init__.py
@@ -1,4 +1,3 @@
-import datetime
 import enum
 import hashlib
 import json
@@ -6,13 +5,11 @@
 import pathlib
 import typing
 
-from collections import defaultdict, namedtuple
 from itertools import chain
 from typing import (
-    TYPE_CHECKING,
     AbstractSet,
-    ClassVar,
     Dict,
+    Iterable,
     List,
     Optional,
     Sequence,
@@ -21,28 +18,19 @@
     Union,
 )
 
-
-if TYPE_CHECKING:
-    from hashlib import _Hash
-
-from pydantic import BaseModel, Field, validator
+from pydantic import BaseModel, validator
 from typing_extensions import Literal
 
-from conda_lock.common import ordered_union, relative_path, suffix_union
+from conda_lock.common import suffix_union
 from conda_lock.errors import ChannelAggregationError
+from conda_lock.models import StrictModel
 from conda_lock.models.channel import Channel
 from conda_lock.virtual_package import FakeRepoData
 
 
-logger = logging.getLogger(__name__)
-
+DEFAULT_PLATFORMS = {"osx-64", "linux-64", "win-64"}
 
-class StrictModel(BaseModel):
-    class Config:
-        extra = "forbid"
-        json_encoders = {
-            frozenset: list,
-        }
+logger = logging.getLogger(__name__)
 
 
 class Selectors(StrictModel):
@@ -67,7 +55,9 @@ class _BaseDependency(StrictModel):
     optional: bool = False
     category: str = "main"
     extras: List[str] = []
-    selectors: Selectors = Selectors()
+
+    def to_source(self) -> "SourceDependency":
+        return SourceDependency(dep=self)  # type: ignore
 
 
 class VersionedDependency(_BaseDependency):
@@ -84,45 +74,14 @@ class URLDependency(_BaseDependency):
 Dependency = Union[VersionedDependency, URLDependency]
 
 
-class Package(StrictModel):
-    url: str
-    hash: str
-
-
-class DependencySource(StrictModel):
-    type: Literal["url"]
-    url: str
-
-
-LockKey = namedtuple("LockKey", ["manager", "name", "platform"])
-
-
-class HashModel(StrictModel):
-    md5: Optional[str] = None
-    sha256: Optional[str] = None
+class SourceDependency(StrictModel):
+    dep: Dependency
+    selectors: Selectors = Selectors()
 
 
-class LockedDependency(StrictModel):
-    name: str
-    version: str
-    manager: Literal["conda", "pip"]
-    platform: str
-    dependencies: Dict[str, str] = {}
+class Package(StrictModel):
     url: str
-    hash: HashModel
-    optional: bool = False
-    category: str = "main"
-    source: Optional[DependencySource] = None
-    build: Optional[str] = None
-
-    def key(self) -> LockKey:
-        return LockKey(self.manager, self.name, self.platform)
-
-    @validator("hash")
-    def validate_hash(cls, v: HashModel, values: Dict[str, typing.Any]) -> HashModel:
-        if (values["manager"] == "conda") and (v.md5 is None):
-            raise ValueError("conda package hashes must use MD5")
-        return v
+    hash: str
 
 
 class MetadataOption(enum.Enum):
@@ -134,313 +93,46 @@ class MetadataOption(enum.Enum):
     InputSha = "input_sha"
 
 
-class TimeMeta(StrictModel):
-    """Stores information about when the lockfile was generated."""
-
-    created_at: str = Field(..., description="Time stamp of lock-file creation time")
-
-    @classmethod
-    def create(cls) -> "TimeMeta":
-        return cls(
-            created_at=datetime.datetime.utcnow().isoformat(timespec="seconds") + "Z"
-        )
-
-
-class GitMeta(StrictModel):
-    """
-    Stores information about the git repo the lockfile is being generated in (if applicable) and
-    the git user generating the file.
-    """
-
-    git_user_name: Optional[str] = Field(
-        default=None, description="Git user.name field of global config"
-    )
-    git_user_email: Optional[str] = Field(
-        default=None, description="Git user.email field of global config"
-    )
-    git_sha: Optional[str] = Field(
-        default=None,
-        description=(
-            "sha256 hash of the most recent git commit that modified one of the input files for "
-            + "this lockfile"
-        ),
-    )
-
-    @classmethod
-    def create(
-        cls,
-        metadata_choices: AbstractSet[MetadataOption],
-        src_files: List[pathlib.Path],
-    ) -> "GitMeta | None":
-        try:
-            import git
-        except ImportError:
-            return None
-
-        git_sha: "str | None" = None
-        git_user_name: "str | None" = None
-        git_user_email: "str | None" = None
-
-        try:
-            repo = git.Repo(search_parent_directories=True)  # type: ignore
-            if MetadataOption.GitSha in metadata_choices:
-                most_recent_datetime: Optional[datetime.datetime] = None
-                for src_file in src_files:
-                    relative_src_file_path = relative_path(
-                        pathlib.Path(repo.working_tree_dir), src_file  # type: ignore
-                    )
-                    commit = list(
-                        repo.iter_commits(paths=relative_src_file_path, max_count=1)
-                    )[0]
-                    if repo.is_dirty(path=relative_src_file_path):
-                        logger.warning(
-                            "One of the inputs to conda-lock is dirty, using commit hash of head +"
-                            ' "dirty"'
-                        )
-                        git_sha = f"{repo.head.object.hexsha}-dirty"
-                        break
-                    else:
-                        if (
-                            most_recent_datetime is None
-                            or most_recent_datetime < commit.committed_datetime
-                        ):
-                            most_recent_datetime = commit.committed_datetime
-                            git_sha = commit.hexsha
-            if MetadataOption.GitUserName in metadata_choices:
-                git_user_name = repo.config_reader().get_value("user", "name", None)  # type: ignore
-            if MetadataOption.GitUserEmail in metadata_choices:
-                git_user_email = repo.config_reader().get_value("user", "email", None)  # type: ignore
-        except git.exc.InvalidGitRepositoryError:  # type: ignore
-            pass
-
-        if any([git_sha, git_user_name, git_user_email]):
-            return cls(
-                git_sha=git_sha,
-                git_user_name=git_user_name,
-                git_user_email=git_user_email,
-            )
-        else:
-            return None
-
-
-class InputMeta(StrictModel):
-    """Stores information about an input provided to generate the lockfile."""
-
-    md5: Optional[str] = Field(..., description="md5 checksum for an input file")
-    sha256: Optional[str] = Field(..., description="md5 checksum for an input file")
-
-    @classmethod
-    def create(
-        cls, metadata_choices: AbstractSet[MetadataOption], src_file: pathlib.Path
-    ) -> "InputMeta":
-        if MetadataOption.InputSha in metadata_choices:
-            sha256 = cls.get_input_sha256(src_file=src_file)
-        else:
-            sha256 = None
-        if MetadataOption.InputMd5 in metadata_choices:
-            md5 = cls.get_input_md5(src_file=src_file)
-        else:
-            md5 = None
-        return cls(
-            md5=md5,
-            sha256=sha256,
-        )
-
-    @classmethod
-    def get_input_md5(cls, src_file: pathlib.Path) -> str:
-        hasher = hashlib.md5()
-        return cls.hash_file(src_file=src_file, hasher=hasher)
-
-    @classmethod
-    def get_input_sha256(cls, src_file: pathlib.Path) -> str:
-        hasher = hashlib.sha256()
-        return cls.hash_file(src_file=src_file, hasher=hasher)
-
-    @staticmethod
-    def hash_file(src_file: pathlib.Path, hasher: "_Hash") -> str:
-        with src_file.open("r") as infile:
-            hasher.update(infile.read().encode("utf-8"))
-        return hasher.hexdigest()
-
-
-class LockMeta(StrictModel):
-    content_hash: Dict[str, str] = Field(
-        ..., description="Hash of dependencies for each target platform"
-    )
-    channels: List[Channel] = Field(
-        ..., description="Channels used to resolve dependencies"
-    )
-    platforms: List[str] = Field(..., description="Target platforms")
-    sources: List[str] = Field(
-        ...,
-        description="paths to source files, relative to the parent directory of the lockfile",
-    )
-    time_metadata: Optional[TimeMeta] = Field(
-        default=None, description="Metadata dealing with the time lockfile was created"
-    )
-    git_metadata: Optional[GitMeta] = Field(
-        default=None,
-        description=(
-            "Metadata dealing with the git repo the lockfile was created in and the user that created it"
-        ),
-    )
-    inputs_metadata: Optional[Dict[str, InputMeta]] = Field(
-        default=None,
-        description="Metadata dealing with the input files used to create the lockfile",
-    )
-    custom_metadata: Optional[Dict[str, str]] = Field(
-        default=None,
-        description="Custom metadata provided by the user to be added to the lockfile",
-    )
-
-    def __or__(self, other: "LockMeta") -> "LockMeta":
-        """merge other into self"""
-        if other is None:
-            return self
-        elif not isinstance(other, LockMeta):
-            raise TypeError
-
-        if self.inputs_metadata is None:
-            new_inputs_metadata = other.inputs_metadata
-        elif other.inputs_metadata is None:
-            new_inputs_metadata = self.inputs_metadata
-        else:
-            new_inputs_metadata = self.inputs_metadata
-            new_inputs_metadata.update(other.inputs_metadata)
+class SourceFile(StrictModel):
+    file: pathlib.Path
+    dependencies: List[SourceDependency]
+    # TODO: Should we store the auth info in here?
+    channels: List[Channel]
+    platforms: Set[str]
 
-        if self.custom_metadata is None:
-            new_custom_metadata = other.custom_metadata
-        elif other.custom_metadata is None:
-            new_custom_metadata = self.custom_metadata
-        else:
-            new_custom_metadata = self.custom_metadata
-            for key in other.custom_metadata:
-                if key in new_custom_metadata:
-                    logger.warning(
-                        f"Custom metadata key {key} provided twice, overwriting original value"
-                        + f"({new_custom_metadata[key]}) with new value "
-                        + f"({other.custom_metadata[key]})"
-                    )
-            new_custom_metadata.update(other.custom_metadata)
-        return LockMeta(
-            content_hash={**self.content_hash, **other.content_hash},
-            channels=self.channels,
-            platforms=sorted(set(self.platforms).union(other.platforms)),
-            sources=ordered_union([self.sources, other.sources]),
-            time_metadata=other.time_metadata,
-            git_metadata=other.git_metadata,
-            inputs_metadata=new_inputs_metadata,
-            custom_metadata=new_custom_metadata,
-        )
-
-    @validator("channels", pre=True, always=True)
-    def ensure_channels(cls, v: List[Union[str, Channel]]) -> List[Channel]:
-        res = []
-        for e in v:
+    @validator("channels", pre=True)
+    def validate_channels(cls, v: List[Union[Channel, str]]) -> List[Channel]:
+        for i, e in enumerate(v):
             if isinstance(e, str):
-                res.append(Channel.from_string(e))
-            else:
-                res.append(e)
-        return typing.cast(List[Channel], res)
-
-
-class Lockfile(StrictModel):
-
-    version: ClassVar[int] = 1
-
-    package: List[LockedDependency]
-    metadata: LockMeta
-
-    def __or__(self, other: "Lockfile") -> "Lockfile":
-        return other.__ror__(self)
-
-    def __ror__(self, other: "Optional[Lockfile]") -> "Lockfile":
-        """
-        merge self into other
-        """
-        if other is None:
-            return self
-        elif not isinstance(other, Lockfile):
-            raise TypeError
-
-        assert self.metadata.channels == other.metadata.channels
-
-        ours = {d.key(): d for d in self.package}
-        theirs = {d.key(): d for d in other.package}
-
-        # Pick ours preferentially
-        package: List[LockedDependency] = []
-        for key in sorted(set(ours.keys()).union(theirs.keys())):
-            if key not in ours or key[-1] not in self.metadata.platforms:
-                package.append(theirs[key])
-            else:
-                package.append(ours[key])
-
-        # Resort the conda packages topologically
-        final_package = self._toposort(package)
-        return Lockfile(package=final_package, metadata=other.metadata | self.metadata)
-
-    def toposort_inplace(self) -> None:
-        self.package = self._toposort(self.package)
-
-    @staticmethod
-    def _toposort(
-        package: List[LockedDependency], update: bool = False
-    ) -> List[LockedDependency]:
-        platforms = {d.platform for d in package}
-
-        # Resort the conda packages topologically
-        final_package: List[LockedDependency] = []
-        for platform in sorted(platforms):
-            from .._vendor.conda.common.toposort import toposort
-
-            # Add the remaining non-conda packages in the order in which they appeared.
-            # Order the pip packages topologically ordered (might be not 100% perfect if they depend on
-            # other conda packages, but good enough
-            for manager in ["conda", "pip"]:
-                lookup = defaultdict(set)
-                packages: Dict[str, LockedDependency] = {}
-
-                for d in package:
-                    if d.platform != platform:
-                        continue
-
-                    if d.manager != manager:
-                        continue
-
-                    lookup[d.name] = set(d.dependencies)
-                    packages[d.name] = d
-
-                ordered = toposort(lookup)
-                for package_name in ordered:
-                    # since we could have a pure dep in here, that does not have a package
-                    # eg a pip package that depends on a conda package (the conda package will not be in this list)
-                    dep = packages.get(package_name)
-                    if dep is None:
-                        continue
-                    if dep.manager != manager:
-                        continue
-                    # skip virtual packages
-                    if dep.manager == "conda" and dep.name.startswith("__"):
-                        continue
+                v[i] = Channel.from_string(e)
+        return typing.cast(List[Channel], v)
 
-                    final_package.append(dep)
+    def spec(self, platform: str) -> List[Dependency]:
+        from conda_lock.src_parser.selectors import dep_in_platform_selectors
 
-        return final_package
+        return [
+            dep.dep
+            for dep in self.dependencies
+            if dep.selectors.platform is None
+            or dep_in_platform_selectors(dep, platform)
+        ]
 
 
 class LockSpecification(BaseModel):
-    dependencies: List[Dependency]
+    dependencies: Dict[str, List[Dependency]]
     # TODO: Should we store the auth info in here?
     channels: List[Channel]
-    platforms: List[str]
     sources: List[pathlib.Path]
     virtual_package_repo: Optional[FakeRepoData] = None
 
+    @property
+    def platforms(self) -> List[str]:
+        return list(self.dependencies.keys())
+
     def content_hash(self) -> Dict[str, str]:
         return {
             platform: self.content_hash_for_platform(platform)
-            for platform in self.platforms
+            for platform in self.dependencies.keys()
         }
 
     def content_hash_for_platform(self, platform: str) -> str:
@@ -448,8 +140,9 @@ def content_hash_for_platform(self, platform: str) -> str:
             "channels": [c.json() for c in self.channels],
             "specs": [
                 p.dict()
-                for p in sorted(self.dependencies, key=lambda p: (p.manager, p.name))
-                if p.selectors.for_platform(platform)
+                for p in sorted(
+                    self.dependencies[platform], key=lambda p: (p.manager, p.name)
+                )
             ],
         }
         if self.virtual_package_repo is not None:
@@ -470,107 +163,98 @@ def validate_channels(cls, v: List[Union[Channel, str]]) -> List[Channel]:
         return typing.cast(List[Channel], v)
 
 
-def _apply_categories(
-    requested: Dict[str, Dependency],
-    planned: Dict[str, LockedDependency],
-    categories: Sequence[str] = ("main", "dev"),
-) -> None:
-    """map each package onto the root request the with the highest-priority category"""
-    # walk dependency tree to assemble all transitive dependencies by request
-    dependents: Dict[str, Set[str]] = {}
-    by_category = defaultdict(list)
-
-    def seperator_munge_get(
-        d: Dict[str, LockedDependency], key: str
-    ) -> LockedDependency:
-        # since separators are not consistent across managers (or even within) we need to do some double attempts here
-        try:
-            return d[key]
-        except KeyError:
-            try:
-                return d[key.replace("-", "_")]
-            except KeyError:
-                return d[key.replace("_", "-")]
-
-    for name, request in requested.items():
-        todo: List[str] = list()
-        deps: Set[str] = set()
-        item = name
-        while True:
-            todo.extend(
-                dep
-                for dep in seperator_munge_get(planned, item).dependencies
-                # exclude virtual packages
-                if not (dep in deps or dep.startswith("__"))
-            )
-            if todo:
-                item = todo.pop(0)
-                deps.add(item)
-            else:
-                break
-
-        dependents[name] = deps
-
-        by_category[request.category].append(request.name)
-
-    # now, map each package to its root request
-    categories = [*categories, *(k for k in by_category if k not in categories)]
-    root_requests = {}
-    for category in categories:
-        for root in by_category.get(category, []):
-            for transitive_dep in dependents[root]:
-                if transitive_dep not in root_requests:
-                    root_requests[transitive_dep] = root
-    # include root requests themselves
-    for name in requested:
-        root_requests[name] = name
-
-    for dep, root in root_requests.items():
-        source = requested[root]
-        # try a conda target first
-        target = seperator_munge_get(planned, dep)
-        target.category = source.category
-        target.optional = source.optional
-
-
-def aggregate_lock_specs(
-    lock_specs: List[LockSpecification],
-) -> LockSpecification:
+def aggregate_deps(grouped_deps: List[List[Dependency]]) -> List[Dependency]:
 
-    # unique dependencies
+    # List unique dependencies
     unique_deps: Dict[Tuple[str, str], Dependency] = {}
-    for dep in chain.from_iterable(
-        [lock_spec.dependencies for lock_spec in lock_specs]
-    ):
+    for dep in chain.from_iterable(grouped_deps):
         key = (dep.manager, dep.name)
-        if key in unique_deps:
-            # Override existing, but merge selectors
-            previous_selectors = unique_deps[key].selectors
-            previous_selectors |= dep.selectors
-            dep.selectors = previous_selectors
         unique_deps[key] = dep
 
-    dependencies = list(unique_deps.values())
-    try:
-        channels = suffix_union(lock_spec.channels or [] for lock_spec in lock_specs)
-    except ValueError as e:
-        raise ChannelAggregationError(*e.args)
+    return list(unique_deps.values())
 
-    return LockSpecification(
-        dependencies=dependencies,
-        # Ensure channel are correctly ordered
-        channels=channels,
-        # uniquify metadata, preserving order
-        platforms=ordered_union(lock_spec.platforms or [] for lock_spec in lock_specs),
-        sources=ordered_union(lock_spec.sources or [] for lock_spec in lock_specs),
+
+def aggregate_channels(
+    channels: Iterable[List[Channel]],
+    channel_overrides: Optional[Sequence[str]] = None,
+) -> List[Channel]:
+    if channel_overrides:
+        return [Channel.from_string(co) for co in channel_overrides]
+    else:
+        # Ensure channels are correctly ordered
+        try:
+            return suffix_union(channels)
+        except ValueError as e:
+            raise ChannelAggregationError(*e.args)
+
+
+def parse_source_files(
+    src_file_paths: List[pathlib.Path], pip_support: bool = True
+) -> List[SourceFile]:
+    """
+    Parse a sequence of dependency specifications from source files
+
+    Parameters
+    ----------
+    src_files :
+        Files to parse for dependencies
+    pip_support :
+        Support pip dependencies
+    """
+    from conda_lock.src_parser.environment_yaml import parse_environment_file
+    from conda_lock.src_parser.meta_yaml import parse_meta_yaml_file
+    from conda_lock.src_parser.pyproject_toml import parse_pyproject_toml
+
+    src_files: List[SourceFile] = []
+    for src_file_path in src_file_paths:
+        if src_file_path.name in ("meta.yaml", "meta.yml"):
+            src_files.append(parse_meta_yaml_file(src_file_path))
+        elif src_file_path.name == "pyproject.toml":
+            src_files.append(parse_pyproject_toml(src_file_path))
+        else:
+            src_files.append(
+                parse_environment_file(
+                    src_file_path,
+                    pip_support=pip_support,
+                )
+            )
+    return src_files
+
+
+def make_lock_spec(
+    *,
+    src_file_paths: List[pathlib.Path],
+    virtual_package_repo: FakeRepoData,
+    channel_overrides: Optional[Sequence[str]] = None,
+    platform_overrides: Optional[Set[str]] = None,
+    required_categories: Optional[AbstractSet[str]] = None,
+    pip_support: bool = True,
+) -> LockSpecification:
+    """Generate the lockfile specs from a set of input src_files. If required_categories is set filter out specs that do not match those"""
+    src_files = parse_source_files(src_file_paths, pip_support)
+
+    # Determine Platforms to Render for
+    platforms = (
+        platform_overrides
+        or {plat for sf in src_files for plat in sf.platforms}
+        or DEFAULT_PLATFORMS
     )
+    spec = {
+        plat: aggregate_deps([sf.spec(plat) for sf in src_files]) for plat in platforms
+    }
+
+    if required_categories is not None:
+        spec = {
+            plat: [d for d in deps if d.category in required_categories]
+            for plat, deps in spec.items()
+        }
 
-
-class UpdateSpecification:
-    def __init__(
-        self,
-        locked: Optional[List[LockedDependency]] = None,
-        update: Optional[List[str]] = None,
-    ):
-        self.locked = locked or []
-        self.update = update or []
+    return LockSpecification(
+        dependencies=spec,
+        channels=aggregate_channels(
+            (sf.channels for sf in src_files), channel_overrides
+        ),
+        sources=src_file_paths,
+        virtual_package_repo=virtual_package_repo,
+    )
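Annotation: unlike the old `aggregate_lock_specs`, the merge above no longer unions selectors. Selectors now live on `SourceDependency` and are resolved per platform before aggregation, so `aggregate_deps` is a plain last-wins merge keyed on `(manager, name)`. A small check, assuming this branch is installed:

    from conda_lock.src_parser import VersionedDependency, aggregate_deps

    # Later source files win: the second spec for numpy replaces the first.
    first = [VersionedDependency(name="numpy", version=">=1.21", manager="conda")]
    second = [VersionedDependency(name="numpy", version=">=1.23", manager="conda")]

    merged = aggregate_deps([first, second])
    print([(d.name, d.version) for d in merged])  # [('numpy', '>=1.23')]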
diff --git a/conda_lock/src_parser/conda_common.py b/conda_lock/src_parser/conda_common.py
index 4d3c146d0..0ec7665d0 100644
--- a/conda_lock/src_parser/conda_common.py
+++ b/conda_lock/src_parser/conda_common.py
@@ -2,10 +2,10 @@
 
 from .._vendor.conda.models.channel import Channel
 from .._vendor.conda.models.match_spec import MatchSpec
-from ..src_parser import VersionedDependency
+from ..src_parser import SourceDependency, VersionedDependency
 
 
-def conda_spec_to_versioned_dep(spec: str, category: str) -> VersionedDependency:
+def conda_spec_to_versioned_dep(spec: str, category: str) -> SourceDependency:
     """Convert a string form conda spec into a versioned dependency for a given category.
 
     This is used by the environment.yaml and meta.yaml specification parser
@@ -30,4 +30,4 @@ def conda_spec_to_versioned_dep(spec: str, category: str) -> SourceDependency:
         extras=[],
         build=ms.get("build"),
         conda_channel=channel_str,
-    )
+    ).to_source()
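Annotation: the helper now wraps the parsed `VersionedDependency` in a `SourceDependency`, so the caller can attach selectors afterwards. Roughly, assuming this branch is installed:

    from conda_lock.src_parser.conda_common import conda_spec_to_versioned_dep

    sdep = conda_spec_to_versioned_dep("numpy >=1.21,<2", category="main")
    print(type(sdep).__name__)              # SourceDependency
    print(sdep.dep.name, sdep.dep.version)  # numpy >=1.21,<2
    print(sdep.selectors.platform)          # None until a selector comment is parsed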
diff --git a/conda_lock/src_parser/environment_yaml.py b/conda_lock/src_parser/environment_yaml.py
index a021203fb..d6e94a73c 100644
--- a/conda_lock/src_parser/environment_yaml.py
+++ b/conda_lock/src_parser/environment_yaml.py
@@ -2,161 +2,86 @@
 import re
 import sys
 
-from typing import List, Optional, Sequence, Tuple
+from typing import List, Set
 
-import yaml
+from ruamel.yaml import YAML
 
-from conda_lock.src_parser import Dependency, LockSpecification, aggregate_lock_specs
+from conda_lock.src_parser import SourceDependency, SourceFile
 from conda_lock.src_parser.conda_common import conda_spec_to_versioned_dep
-from conda_lock.src_parser.selectors import filter_platform_selectors
+from conda_lock.src_parser.selectors import parse_selector_comment_for_dep
 
 from .pyproject_toml import parse_python_requirement
 
 
-_whitespace = re.compile(r"\s+")
-_conda_package_pattern = re.compile(r"^(?P<name>[A-Za-z0-9_-]+)\s?(?P<version>.*)?$")
-
-
-def parse_conda_requirement(req: str) -> Tuple[str, str]:
-    match = _conda_package_pattern.match(req)
-    if match:
-        return match.group("name"), _whitespace.sub("", match.group("version"))
-    else:
-        raise ValueError(f"Can't parse conda spec from '{req}'")
-
-
-def _parse_environment_file_for_platform(
+def parse_environment_file(
     environment_file: pathlib.Path,
-    content: str,
-    platform: str,
     *,
     pip_support: bool = False,
-) -> LockSpecification:
+) -> SourceFile:
     """
-    Parse dependencies from a conda environment specification for an
-    assumed target platform.
-
-    Parameters
-    ----------
-    environment_file :
-        Path to environment.yml
-    pip_support :
-        Emit dependencies in pip section of environment.yml. If False, print a
-        warning and ignore pip dependencies.
-    platform :
-        Target platform to use when parsing selectors to filter lines
+    Parse a simple environment-yaml file in a platform and version independent way.
     """
-    filtered_content = "\n".join(filter_platform_selectors(content, platform=platform))
-    env_yaml_data = yaml.safe_load(filtered_content)
-
-    specs = env_yaml_data["dependencies"]
-    channels: List[str] = env_yaml_data.get("channels", [])
+    if not environment_file.exists():
+        raise FileNotFoundError(f"Environment File {environment_file} not found")
+    env_yaml_data = YAML().load(environment_file)
 
-    # These extension fields are nonstandard
-    platforms: List[str] = env_yaml_data.get("platforms", [])
-    category: str = env_yaml_data.get("category") or "main"
+    # Get any (nonstandard) given values in the environment file
+    platforms: Set[str] = set(env_yaml_data.get("platforms", []))
+    category: str = str(env_yaml_data.get("category", "main"))
+    channels: List[str] = env_yaml_data.get("channels", []).copy()
 
-    # Split out any sub spec sections from the dependencies mapping
-    mapping_specs = [x for x in specs if not isinstance(x, str)]
-    specs = [x for x in specs if isinstance(x, str)]
+    all_specs = env_yaml_data["dependencies"]
+    specs = [x for x in all_specs if isinstance(x, str)]
+    mapping_specs = [x for x in all_specs if not isinstance(x, str)]
 
-    dependencies: List[Dependency] = []
-    for spec in specs:
-        vdep = conda_spec_to_versioned_dep(spec, category)
-        vdep.selectors.platform = [platform]
-        dependencies.append(vdep)
+    # Get and Parse Dependencies
+    dependencies: List[SourceDependency] = []
+    for idx, spec in enumerate(specs):
+        sdep = conda_spec_to_versioned_dep(spec, category)
+        sdep.selectors.platform = parse_selector_comment_for_dep(all_specs.ca, idx)
+        dependencies.append(sdep)
 
     for mapping_spec in mapping_specs:
-        if "pip" in mapping_spec:
-            if pip_support:
-                for spec in mapping_spec["pip"]:
-                    if re.match(r"^-e .*$", spec):
-                        print(
-                            (
-                                f"Warning: editable pip dep '{spec}' will not be included in the lock file. "
-                                "You will need to install it separately."
-                            ),
-                            file=sys.stderr,
-                        )
-                        continue
-
-                    dependencies.append(
-                        parse_python_requirement(
-                            spec,
-                            manager="pip",
-                            optional=category != "main",
-                            category=category,
-                            normalize_name=False,
-                        )
+        if "pip" not in mapping_spec:
+            continue
+
+        if pip_support:
+            for spec in mapping_spec["pip"]:
+                if re.match(r"^-e .*$", spec):
+                    print(
+                        (
+                            f"Warning: editable pip dep '{spec}' will not be included in the lock file. "
+                            "You will need to install it separately."
+                        ),
+                        file=sys.stderr,
+                    )
+                    continue
+
+                dependencies.append(
+                    parse_python_requirement(
+                        spec,
+                        manager="pip",
+                        optional=category != "main",
+                        category=category,
+                        normalize_name=False,
                     )
+                )
 
-                # ensure pip is in target env
-                dependencies.append(parse_python_requirement("pip", manager="conda"))
-            else:
-                print(
-                    (
-                        "Warning: found pip deps, but conda-lock was installed without pypi support. "
-                        "pip dependencies will not be included in the lock file. Either install them "
-                        "separately, or install conda-lock with `-E pip_support`."
-                    ),
-                    file=sys.stderr,
-                )
+            # ensure pip is in target env
+            dependencies.append(parse_python_requirement("pip", manager="conda"))
+        else:
+            print(
+                (
+                    "Warning: found pip deps, but conda-lock was installed without pypi support. "
+                    "pip dependencies will not be included in the lock file. Either install them "
+                    "separately, or install conda-lock with `-E pip_support`."
+                ),
+                file=sys.stderr,
+            )
 
-    return LockSpecification(
+    return SourceFile(
+        file=environment_file,
         dependencies=dependencies,
         channels=channels,  # type: ignore
         platforms=platforms,
-        sources=[environment_file],
     )
-
-
-def parse_environment_file(
-    environment_file: pathlib.Path,
-    given_platforms: Optional[Sequence[str]],
-    *,
-    default_platforms: List[str] = [],
-    pip_support: bool = False,
-) -> LockSpecification:
-    """Parse a simple environment-yaml file for dependencies assuming the target platforms.
-
-    * This will emit one dependency set per target platform. These may differ
-      if the dependencies depend on platform selectors.
-    * This does not support multi-output files and will ignore all lines with
-      selectors other than platform.
-    """
-    if not environment_file.exists():
-        raise FileNotFoundError(f"{environment_file} not found")
-
-    with environment_file.open("r") as fo:
-        content = fo.read()
-    env_yaml_data = yaml.safe_load(content)
-
-    # Get list of platforms from the input file
-    yaml_platforms: Optional[List[str]] = env_yaml_data.get("platforms")
-    # Final list of platforms is the following order of priority
-    # 1) List Passed in via the -p flag (if any given)
-    # 2) List From the YAML File (if specified)
-    # 3) Default List of Platforms to Render
-    platforms = list(given_platforms or yaml_platforms or default_platforms)
-
-    # Parse with selectors for each target platform
-    spec = aggregate_lock_specs(
-        [
-            _parse_environment_file_for_platform(
-                environment_file, content, platform, pip_support=pip_support
-            )
-            for platform in platforms
-        ]
-    )
-
-    # Remove platform selectors if they apply to all targets
-    for dep in spec.dependencies:
-        if dep.selectors.platform == platforms:
-            dep.selectors.platform = None
-
-    # Use the list of rendered platforms for the output spec only if
-    # there is a dependency that is not used on all platforms.
-    # This is unlike meta.yaml because environment-yaml files can contain an
-    # internal list of platforms, which should be used as long as it
-    spec.platforms = platforms
-
-    return spec
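Annotation: end to end, the new parser keeps a dependency's `# [platform]` comment as a selector (via ruamel.yaml round-trip comments) instead of pre-filtering the file once per platform. A sketch, assuming this branch is installed; the environment file is a throwaway temp file:

    import pathlib
    import tempfile
    import textwrap

    from conda_lock.src_parser.environment_yaml import parse_environment_file

    content = textwrap.dedent(
        """\
        channels:
          - conda-forge
        dependencies:
          - python=3.10
          - libgcc  # [linux]
        """
    )
    with tempfile.TemporaryDirectory() as tmp:
        env_file = pathlib.Path(tmp, "environment.yml")
        env_file.write_text(content)
        src = parse_environment_file(env_file, pip_support=False)

    for sdep in src.dependencies:
        print(sdep.dep.name, sdep.selectors.platform)  # python None / libgcc ['linux']
    print([d.name for d in src.spec("linux-64")])  # includes libgcc
    print([d.name for d in src.spec("win-64")])    # excludes it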
diff --git a/conda_lock/src_parser/meta_yaml.py b/conda_lock/src_parser/meta_yaml.py
index fd377133e..c52a2bf7a 100644
--- a/conda_lock/src_parser/meta_yaml.py
+++ b/conda_lock/src_parser/meta_yaml.py
@@ -1,13 +1,18 @@
 import pathlib
 
-from typing import Any, Dict, List, Optional
+from typing import TYPE_CHECKING, List, Optional
 
 import jinja2
-import yaml
 
-from conda_lock.common import get_in
-from conda_lock.src_parser import Dependency, LockSpecification, aggregate_lock_specs
-from conda_lock.src_parser.selectors import filter_platform_selectors
+from ruamel.yaml import YAML
+
+from conda_lock.src_parser import SourceDependency, SourceFile
+from conda_lock.src_parser.conda_common import conda_spec_to_versioned_dep
+from conda_lock.src_parser.selectors import parse_selector_comment_for_dep
+
+
+if TYPE_CHECKING:
+    from ruamel.yaml.comments import CommentedMap, CommentedSeq
 
 
 class UndefinedNeverFail(jinja2.Undefined):
@@ -81,80 +86,45 @@ def _return_value(self, value=None):  # type: ignore
         return value
 
 
-def parse_meta_yaml_file(
-    meta_yaml_file: pathlib.Path,
-    platforms: List[str],
-) -> LockSpecification:
-    """Parse a simple meta-yaml file for dependencies assuming the target platforms.
+def parse_meta_yaml_file(meta_yaml_file: pathlib.Path) -> SourceFile:
+    """Parse a simple meta-yaml file for dependencies.
 
-    * This will emit one dependency set per target platform. These may differ
-      if the dependencies depend on platform selectors.
     * This does not support multi-output files and will ignore all lines with
       selectors other than platform.
     """
-    # parse with selectors for each target platform
-    spec = aggregate_lock_specs(
-        [
-            _parse_meta_yaml_file_for_platform(meta_yaml_file, platform)
-            for platform in platforms
-        ]
-    )
-
-    # remove platform selectors if they apply to all targets
-    for dep in spec.dependencies:
-        if dep.selectors.platform == platforms:
-            dep.selectors.platform = None
-
-    return spec
-
-
-def _parse_meta_yaml_file_for_platform(
-    meta_yaml_file: pathlib.Path,
-    platform: str,
-) -> LockSpecification:
-    """Parse a simple meta-yaml file for dependencies, assuming the target platform.
-
-    * This does not support multi-output files and will ignore all lines with selectors other than platform
-    """
     if not meta_yaml_file.exists():
         raise FileNotFoundError(f"{meta_yaml_file} not found")
 
     with meta_yaml_file.open("r") as fo:
-        filtered_recipe = "\n".join(
-            filter_platform_selectors(fo.read(), platform=platform)
-        )
-        t = jinja2.Template(filtered_recipe, undefined=UndefinedNeverFail)
-        rendered = t.render()
-
-        meta_yaml_data = yaml.safe_load(rendered)
-
-    channels = get_in(["extra", "channels"], meta_yaml_data, [])
-    dependencies: List[Dependency] = []
+        recipe = fo.read()
 
-    def add_spec(spec: str, category: str) -> None:
-        if spec is None:
-            return
+    t = jinja2.Template(recipe, undefined=UndefinedNeverFail)
+    rendered = t.render()
+    meta_yaml_data = YAML().load(rendered)
 
-        from .conda_common import conda_spec_to_versioned_dep
+    channels = meta_yaml_data.mlget(["extra", "channels"], []).copy()
+    dependencies: List[SourceDependency] = []
 
-        dep = conda_spec_to_versioned_dep(spec, category)
-        dep.selectors.platform = [platform]
-        dependencies.append(dep)
+    def add_specs(group: "CommentedSeq", category: str) -> None:
+        for idx, spec in enumerate(group):
+            if spec is None:
+                continue
+            dep = conda_spec_to_versioned_dep(spec, category)
+            dep.selectors.platform = parse_selector_comment_for_dep(group.ca, idx)
+            dependencies.append(dep)
 
-    def add_requirements_from_recipe_or_output(yaml_data: Dict[str, Any]) -> None:
-        for s in get_in(["requirements", "host"], yaml_data, []):
-            add_spec(s, "main")
-        for s in get_in(["requirements", "run"], yaml_data, []):
-            add_spec(s, "main")
-        for s in get_in(["test", "requires"], yaml_data, []):
-            add_spec(s, "dev")
+    def add_requirements_from_recipe_or_output(yaml_data: "CommentedMap") -> None:
+        add_specs(yaml_data.mlget(["requirements", "host"], []), "main")
+        add_specs(yaml_data.mlget(["requirements", "run"], []), "main")
+        add_specs(yaml_data.mlget(["test", "requires"], []), "dev")
 
     add_requirements_from_recipe_or_output(meta_yaml_data)
-    for output in get_in(["outputs"], meta_yaml_data, []):
+    for output in meta_yaml_data.get("outputs", []):
         add_requirements_from_recipe_or_output(output)
 
-    return LockSpecification(
+    return SourceFile(
+        file=meta_yaml_file,
         dependencies=dependencies,
-        channels=channels,
-        platforms=[platform],
-        sources=[meta_yaml_file],
+        channels=channels,  # type: ignore
+        platforms=set(),
     )
VersionedDependency, ) @@ -76,7 +76,7 @@ def poetry_version_to_conda_version(version_string: Optional[str]) -> Optional[s def parse_poetry_pyproject_toml( path: pathlib.Path, contents: Mapping[str, Any], -) -> LockSpecification: +) -> SourceFile: """ Parse dependencies from a poetry pyproject.toml file @@ -92,7 +92,7 @@ def parse_poetry_pyproject_toml( * markers are not supported """ - dependencies: List[Dependency] = [] + dependencies: List[SourceDependency] = [] categories = {"dependencies": "main", "dev-dependencies": "dev"} @@ -150,7 +150,7 @@ def parse_poetry_pyproject_toml( optional=optional, category=category, extras=extras, - ) + ).to_source() ) else: dependencies.append( @@ -161,15 +161,17 @@ def parse_poetry_pyproject_toml( optional=optional, category=category, extras=extras, - ) + ).to_source() ) return specification_with_dependencies(path, contents, dependencies) def specification_with_dependencies( - path: pathlib.Path, toml_contents: Mapping[str, Any], dependencies: List[Dependency] -) -> LockSpecification: + path: pathlib.Path, + toml_contents: Mapping[str, Any], + dependencies: List[SourceDependency], +) -> SourceFile: force_pypi = set() for depname, depattrs in get_in( ["tool", "conda-lock", "dependencies"], toml_contents, {} @@ -184,7 +186,7 @@ def specification_with_dependencies( optional=False, category="main", extras=[], - ) + ).to_source() ) elif isinstance(depattrs, collections.abc.Mapping): if depattrs.get("source", None) == "pypi": @@ -194,14 +196,14 @@ def specification_with_dependencies( if force_pypi: for dep in dependencies: - if dep.name in force_pypi: - dep.manager = "pip" + if dep.dep.name in force_pypi: + dep.dep.manager = "pip" - return LockSpecification( + return SourceFile( dependencies=dependencies, channels=get_in(["tool", "conda-lock", "channels"], toml_contents, []), - platforms=get_in(["tool", "conda-lock", "platforms"], toml_contents, []), - sources=[path], + platforms=set(get_in(["tool", "conda-lock", "platforms"], toml_contents, [])), + file=path, ) @@ -213,51 +215,13 @@ def to_match_spec(conda_dep_name: str, conda_version: Optional[str]) -> str: return spec -def parse_pyproject_toml( - pyproject_toml: pathlib.Path, -) -> LockSpecification: - with pyproject_toml.open("rb") as fp: - contents = toml_load(fp) - build_system = get_in(["build-system", "build-backend"], contents) - pep_621_probe = get_in(["project", "dependencies"], contents) - pdm_probe = get_in(["tool", "pdm"], contents) - parse = parse_poetry_pyproject_toml - if pep_621_probe is not None: - if pdm_probe is None: - parse = partial( - parse_requirements_pyproject_toml, - prefix=("project",), - main_tag="dependencies", - optional_tag="optional-dependencies", - ) - else: - parse = parse_pdm_pyproject_toml - elif build_system.startswith("poetry"): - parse = parse_poetry_pyproject_toml - elif build_system.startswith("flit"): - parse = partial( - parse_requirements_pyproject_toml, - prefix=("tool", "flit", "metadata"), - main_tag="requires", - optional_tag="requires-extra", - ) - else: - import warnings - - warnings.warn( - "Could not detect build-system in pyproject.toml. 
-            "Could not detect build-system in pyproject.toml. Assuming poetry"
-        )
-
-    return parse(pyproject_toml, contents)
-
-
 def parse_python_requirement(
     requirement: str,
     manager: Literal["conda", "pip"] = "conda",
     optional: bool = False,
     category: str = "main",
     normalize_name: bool = True,
-) -> Dependency:
+) -> SourceDependency:
     """Parse a requirements.txt like requirement to a conda spec"""
     requirement_specifier = requirement.split(";")[0].strip()
     from pkg_resources import Requirement
@@ -286,7 +250,7 @@ def parse_python_requirement(
             extras=extras,
             url=url,
             hashes=[frag.replace("=", ":")],
-        )
+        ).to_source()
     else:
         return VersionedDependency(
             name=conda_dep_name,
@@ -295,7 +259,7 @@ def parse_python_requirement(
             optional=optional,
             category=category,
             extras=extras,
-        )
+        ).to_source()
 
 
 def parse_requirements_pyproject_toml(
@@ -305,11 +269,11 @@ def parse_requirements_pyproject_toml(
     main_tag: str,
     optional_tag: str,
     dev_tags: AbstractSet[str] = {"dev", "test"},
-) -> LockSpecification:
+) -> SourceFile:
     """
     PEP621 and flit
     """
-    dependencies: List[Dependency] = []
+    dependencies: List[SourceDependency] = []
 
     sections = {(*prefix, main_tag): "main"}
     for extra in dev_tags:
@@ -333,7 +297,7 @@ def parse_requirements_pyproject_toml(
 def parse_pdm_pyproject_toml(
     path: pathlib.Path,
     contents: Mapping[str, Any],
-) -> LockSpecification:
+) -> SourceFile:
     """
     PDM support. First, a regular PEP621 pass; then, add all dependencies
     listed in the 'tool.pdm.dev-dependencies' table with the 'dev' category.
@@ -348,7 +312,7 @@ def parse_pdm_pyproject_toml(
     dev_reqs = []
 
-    for section, deps in get_in(["tool", "pdm", "dev-dependencies"], contents).items():
+    for _, deps in get_in(["tool", "pdm", "dev-dependencies"], contents).items():
         dev_reqs.extend(
             [
                 parse_python_requirement(
@@ -361,3 +325,41 @@ def parse_pdm_pyproject_toml(
     res.dependencies.extend(dev_reqs)
 
     return res
+
+
+def parse_pyproject_toml(pyproject_toml: pathlib.Path) -> SourceFile:
+    with pyproject_toml.open("rb") as fp:
+        contents = toml_load(fp)
+
+    build_system: Optional[str] = get_in(["build-system", "build-backend"], contents)
+    pep_621_probe: Optional[str] = get_in(["project", "dependencies"], contents)
+    pdm_probe: Optional[str] = get_in(["tool", "pdm"], contents)
+
+    parse = parse_poetry_pyproject_toml
+    if pep_621_probe is not None:
+        if pdm_probe is None:
+            parse = partial(
+                parse_requirements_pyproject_toml,
+                prefix=("project",),
+                main_tag="dependencies",
+                optional_tag="optional-dependencies",
+            )
+        else:
+            parse = parse_pdm_pyproject_toml
+    elif build_system and build_system.startswith("poetry"):
+        parse = parse_poetry_pyproject_toml
+    elif build_system and build_system.startswith("flit"):
+        parse = partial(
+            parse_requirements_pyproject_toml,
+            prefix=("tool", "flit", "metadata"),
+            main_tag="requires",
+            optional_tag="requires-extra",
+        )
+    else:
+        import warnings
+
+        warnings.warn(
+            "Could not detect build-system in pyproject.toml. Assuming poetry"
+        )
+
+    return parse(pyproject_toml, contents)
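The relocated `parse_pyproject_toml` keeps the old dispatch order but now guards against a missing `build-backend` key: PEP 621 `[project.dependencies]` wins (with `[tool.pdm]` routing to the PDM parser), then the declared backend selects poetry or flit, and anything else warns and assumes poetry. A condensed, standalone sketch of just that decision table (`detect_flavor` and the use of `tomllib` are illustrative assumptions, not conda-lock code):

import tomllib  # Python 3.11+; conda-lock uses its own toml_load helper

def detect_flavor(pyproject_text: str) -> str:
    contents = tomllib.loads(pyproject_text)
    backend = contents.get("build-system", {}).get("build-backend")
    if contents.get("project", {}).get("dependencies") is not None:
        # PEP 621 metadata takes priority; PDM adds its dev-dependency table.
        return "pdm" if "pdm" in contents.get("tool", {}) else "pep621"
    if backend and backend.startswith("poetry"):
        return "poetry"
    if backend and backend.startswith("flit"):
        return "flit"
    return "poetry"  # the fallback that triggers the warning above

assert detect_flavor('[project]\ndependencies = []') == "pep621"
assert detect_flavor('[build-system]\nbuild-backend = "flit_core.buildapi"') == "flit"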
diff --git a/conda_lock/src_parser/selectors.py b/conda_lock/src_parser/selectors.py
index acefffe87..772933ae0 100644
--- a/conda_lock/src_parser/selectors.py
+++ b/conda_lock/src_parser/selectors.py
@@ -1,16 +1,37 @@
 import logging
 import re
 
-from typing import Iterator, Optional
+from typing import TYPE_CHECKING, List, Optional
+
+
+if TYPE_CHECKING:
+    from ruamel.yaml.comments import Comment
+
+    from conda_lock.src_parser import SourceDependency
 
 
 logger = logging.getLogger(__name__)
 
+sel_pat = re.compile(r"(#.*)\[([^\[\]]+)\](?(2)[^\(\)]*)$")
+
+
+def parse_selector_comment_for_dep(
+    yaml_comments: "Comment", dep_idx: int
+) -> Optional[List[str]]:
+    if dep_idx not in yaml_comments.items:
+        return None
+    comment: str = yaml_comments.items[dep_idx][0].value
+    parsed_comment = comment.partition("\n")[0].rstrip()
 
-def filter_platform_selectors(
-    content: str, platform: Optional[str] = None
-) -> Iterator[str]:
-    """ """
+    # This code is adapted from conda-build
+    m = sel_pat.match(parsed_comment)
+    return [m.group(2)] if m else None
+
+
+def dep_in_platform_selectors(
+    source_dep: "SourceDependency",
+    platform: str,
+) -> bool:
     # we support a very limited set of selectors that adhere to platform only
     # https://docs.conda.io/projects/conda-build/en/latest/resources/define-metadata.html#preprocessing-selectors
 
@@ -25,19 +46,10 @@
         "win-64": {"win", "win64"},
     }
 
-    # This code is adapted from conda-build
-    sel_pat = re.compile(r"(.+?)\s*(#.*)\[([^\[\]]+)\](?(2)[^\(\)]*)$")
-    for line in content.splitlines(keepends=False):
-        if line.lstrip().startswith("#"):
-            continue
-        m = sel_pat.match(line)
-        if platform and m:
-            cond = m.group(3)
-            if cond in platform_sel[platform]:
-                yield line
-            else:
-                logger.warning(
-                    "filtered out line `%s` due to unmatchable selector", line
-                )
-        else:
-            yield line
+    return platform in platform_sel and (
+        source_dep.selectors.platform is None
+        or any(
+            sel_elem in platform_sel[platform]
+            for sel_elem in source_dep.selectors.platform
+        )
+    )
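Note that `sel_pat` changed along with its input: the old pattern matched a whole recipe line (code plus trailing comment), while the new one is applied to the extracted comment alone. A quick illustrative check of what the new pattern accepts (standalone sketch, not part of the test suite):

import re

sel_pat = re.compile(r"(#.*)\[([^\[\]]+)\](?(2)[^\(\)]*)$")

for comment in ("# [unix]", "# [osx and arm64]", "# plain comment"):
    m = sel_pat.match(comment)
    print(f"{comment!r} -> {m.group(2) if m else None}")
# '# [unix]' -> 'unix'
# '# [osx and arm64]' -> 'osx and arm64'
# '# plain comment' -> None

Correspondingly, `dep_in_platform_selectors` treats a dependency with no selector as applying to every supported platform, and only rejects it when none of its selector tokens appear in the platform's alias set.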
diff --git a/tests/test_conda_lock.py b/tests/test_conda_lock.py
index 22c7d96b5..770ce26ae 100644
--- a/tests/test_conda_lock.py
+++ b/tests/test_conda_lock.py
@@ -13,8 +13,9 @@
 import uuid
 
 from glob import glob
+from itertools import chain
 from pathlib import Path
-from typing import Any, Dict
+from typing import Any, Dict, List, Optional
 from unittest.mock import MagicMock
 from urllib.parse import urldefrag, urlsplit
@@ -30,20 +31,17 @@
 from conda_lock.conda_lock import (
     DEFAULT_FILES,
     DEFAULT_LOCKFILE_NAME,
-    DEFAULT_PLATFORMS,
     _add_auth_to_line,
     _add_auth_to_lockfile,
     _extract_domain,
     _strip_auth_from_line,
     _strip_auth_from_lockfile,
-    aggregate_lock_specs,
     create_lockfile_from_spec,
     default_virtual_package_repodata,
     determine_conda_executable,
     extract_input_hash,
     main,
     make_lock_spec,
-    parse_meta_yaml_file,
     run_lock,
 )
 from conda_lock.conda_solver import extract_json_object, fake_conda_environment
@@ -58,18 +56,21 @@
     is_micromamba,
     reset_conda_pkgs_dir,
 )
+from conda_lock.lockfile import HashModel, LockedDependency, parse_conda_lock_file
 from conda_lock.models.channel import Channel
 from conda_lock.pypi_solver import parse_pip_requirement, solve_pypi
 from conda_lock.src_parser import (
-    HashModel,
-    LockedDependency,
+    Dependency,
     LockSpecification,
     MetadataOption,
     Selectors,
+    SourceDependency,
     VersionedDependency,
+    aggregate_channels,
+    aggregate_deps,
 )
 from conda_lock.src_parser.environment_yaml import parse_environment_file
-from conda_lock.src_parser.lockfile import parse_conda_lock_file
+from conda_lock.src_parser.meta_yaml import parse_meta_yaml_file
 from conda_lock.src_parser.pyproject_toml import (
     parse_pyproject_toml,
     poetry_version_to_conda_version,
@@ -119,6 +120,18 @@ def gdal_environment(tmp_path: Path):
     return x
 
 
+@pytest.fixture
+def lock_spec_env(tmp_path: Path):
+    main = clone_test_dir("test-lock-spec", tmp_path)
+    return [main.joinpath("base-env.yml"), main.joinpath("ml-stuff.yml")]
+
+
+@pytest.fixture
+def lock_spec_env_platforms(tmp_path: Path):
+    main = clone_test_dir("test-lock-spec-platforms", tmp_path)
+    return [main.joinpath("base-env.yml"), main.joinpath("osx-stuff.yml")]
+
+
 @pytest.fixture
 def filter_conda_environment(tmp_path: Path):
     x = clone_test_dir("test-env-filter-platform", tmp_path).joinpath("environment.yml")
@@ -281,17 +294,26 @@ def custom_json_metadata(custom_metadata_environment: Path) -> Path:
     return outfile
 
 
+def _make_source_dep(
+    name: str, platform: Optional[str] = None, **kwargs
+) -> SourceDependency:
+    return SourceDependency(
+        dep=VersionedDependency(name=name, **kwargs),
+        selectors=Selectors(platform=([platform] if platform else None)),
+    )
+
+
 def test_parse_environment_file(gdal_environment: Path):
-    res = parse_environment_file(gdal_environment, DEFAULT_PLATFORMS, pip_support=True)
+    res = parse_environment_file(gdal_environment, pip_support=True)
     assert all(
         x in res.dependencies
         for x in [
-            VersionedDependency(
+            _make_source_dep(
                 name="python",
                 manager="conda",
                 version=">=3.7,<3.8",
             ),
-            VersionedDependency(
+            _make_source_dep(
                 name="gdal",
                 manager="conda",
                 version="",
@@ -299,7 +321,7 @@ def test_parse_environment_file(gdal_environment: Path):
         ]
     )
     assert (
-        VersionedDependency(
+        _make_source_dep(
             name="toolz",
             manager="pip",
             version="*",
@@ -312,9 +334,9 @@ def test_parse_environment_file(gdal_environment: Path):
 
 
 def test_parse_environment_file_with_pip(pip_environment: Path):
-    res = parse_environment_file(pip_environment, DEFAULT_PLATFORMS, pip_support=True)
-    assert [dep for dep in res.dependencies if dep.manager == "pip"] == [
-        VersionedDependency(
+    res = parse_environment_file(pip_environment, pip_support=True)
+    assert [dep for dep in res.dependencies if dep.dep.manager == "pip"] == [
+        _make_source_dep(
             name="requests-toolbelt",
             manager="pip",
             optional=False,
@@ -326,73 +348,70 @@ def test_parse_environment_file_with_pip(pip_environment: Path):
 
 
 def test_parse_env_file_with_filters_no_args(filter_conda_environment: Path):
-    res = parse_environment_file(filter_conda_environment, None, pip_support=False)
+    res = parse_environment_file(filter_conda_environment, pip_support=False)
     assert all(x in res.platforms for x in ["osx-arm64", "osx-64", "linux-64"])
     assert res.channels == [Channel.from_string("conda-forge")]
 
     assert all(
         x in res.dependencies
         for x in [
-            VersionedDependency(
+            _make_source_dep(
                 name="python",
                 manager="conda",
                 version="<3.11",
             ),
-            VersionedDependency(
+            _make_source_dep(
                 name="clang_osx-arm64",
                 manager="conda",
                 version="",
-                selectors=Selectors(platform=["osx-arm64"]),
+                platform="arm64",
             ),
-            VersionedDependency(
+            _make_source_dep(
                 name="clang_osx-64",
                 manager="conda",
                 version="",
-                selectors=Selectors(platform=["osx-64"]),
+                platform="osx64",
            ),
-            VersionedDependency(
+            _make_source_dep(
                 name="gcc_linux-64",
                 manager="conda",
                 version=">=6",
-                selectors=Selectors(platform=["linux-64"]),
+                platform="linux64",
             ),
         ]
     )
 
 
 def test_parse_env_file_with_filters_defaults(filter_conda_environment: Path):
-    res = parse_environment_file(
-        filter_conda_environment, DEFAULT_PLATFORMS, pip_support=False
-    )
-    assert all(x in res.platforms for x in DEFAULT_PLATFORMS)
+    res = parse_environment_file(filter_conda_environment, pip_support=False)
+    assert all(x in res.platforms for x in ["osx-arm64", "osx-64", "linux-64"])
     assert res.channels == [Channel.from_string("conda-forge")]
 
     assert all(
         x in res.dependencies
         for x in [
-            VersionedDependency(
+            _make_source_dep(
                 name="python",
                 manager="conda",
                 version="<3.11",
             ),
-            VersionedDependency(
+            _make_source_dep(
                 name="clang_osx-64",
                 manager="conda",
                 version="",
-                selectors=Selectors(platform=["osx-64"]),
+                platform="osx64",
             ),
-            VersionedDependency(
+            _make_source_dep(
                 name="gcc_linux-64",
                 manager="conda",
                 version=">=6",
-                selectors=Selectors(platform=["linux-64"]),
+                platform="linux64",
             ),
         ]
     )
 
 
 def test_choose_wheel() -> None:
-
     solution = solve_pypi(
         {
             "fastavro": VersionedDependency(
@@ -505,31 +524,28 @@ def test_parse_pip_requirement(
 
 
 def test_parse_meta_yaml_file(meta_yaml_environment: Path):
-    res = parse_meta_yaml_file(meta_yaml_environment, ["linux-64", "osx-64"])
-    specs = {dep.name: dep for dep in res.dependencies}
+    res = parse_meta_yaml_file(meta_yaml_environment)
+    specs = {dep.dep.name: dep for dep in res.dependencies}
 
     assert all(x in specs for x in ["python", "numpy"])
-    assert all(
-        dep.selectors
-        == Selectors(
-            platform=None
-        )  # Platform will be set to None if all dependencies are the same
-        for dep in specs.values()
-    )
+
     # Ensure that this dep specified by a python selector is ignored
-    assert "enum34" not in specs
+    assert "enum34" in specs
+    assert specs["enum34"].selectors == Selectors(platform=["py27"])
+
     # Ensure that this platform specific dep is included
     assert "zlib" in specs
-    assert specs["pytest"].category == "dev"
-    assert specs["pytest"].optional is True
+    assert specs["zlib"].selectors == Selectors(platform=["unix"])
+
+    assert specs["pytest"].dep.category == "dev"
+    assert specs["pytest"].dep.optional is True
 
 
 def test_parse_poetry(poetry_pyproject_toml: Path):
-    res = parse_pyproject_toml(
-        poetry_pyproject_toml,
-    )
+    res = parse_pyproject_toml(poetry_pyproject_toml)
     specs = {
-        dep.name: typing.cast(VersionedDependency, dep) for dep in res.dependencies
+        dep.dep.name: typing.cast(VersionedDependency, dep.dep)
+        for dep in res.dependencies
     }
 
     assert specs["requests"].version == ">=2.13.0,<3.0.0"
@@ -551,41 +567,43 @@ def test_spec_poetry(poetry_pyproject_toml: Path):
     virtual_package_repo = default_virtual_package_repodata()
     with virtual_package_repo:
         spec = make_lock_spec(
-            src_files=[poetry_pyproject_toml], virtual_package_repo=virtual_package_repo
+            src_file_paths=[poetry_pyproject_toml],
+            virtual_package_repo=virtual_package_repo,
         )
-        deps = {d.name for d in spec.dependencies}
-        assert "tomlkit" in deps
-        assert "pytest" in deps
-        assert "requests" in deps
+        for plat_deps in spec.dependencies.values():
+            deps = {d.name for d in plat_deps}
+            assert "tomlkit" in deps
+            assert "pytest" in deps
+            assert "requests" in deps
 
         spec = make_lock_spec(
-            src_files=[poetry_pyproject_toml],
+            src_file_paths=[poetry_pyproject_toml],
             virtual_package_repo=virtual_package_repo,
             required_categories={"main", "dev"},
        )
-        deps = {d.name for d in spec.dependencies}
-        assert "tomlkit" not in deps
-        assert "pytest" in deps
-        assert "requests" in deps
+        for plat_deps in spec.dependencies.values():
+            deps = {d.name for d in plat_deps}
+            assert "tomlkit" not in deps
+            assert "pytest" in deps
+            assert "requests" in deps
 
         spec = make_lock_spec(
-            src_files=[poetry_pyproject_toml],
+            src_file_paths=[poetry_pyproject_toml],
             virtual_package_repo=virtual_package_repo,
             required_categories={"main"},
         )
-        deps = {d.name for d in spec.dependencies}
-        assert "tomlkit" not in deps
-        assert "pytest" not in deps
-        assert "requests" in deps
+        for plat_deps in spec.dependencies.values():
+            deps = {d.name for d in plat_deps}
+            assert "tomlkit" not in deps
+            assert "pytest" not in deps
+            assert "requests" in deps
 
 
 def test_parse_flit(flit_pyproject_toml: Path):
-    res = parse_pyproject_toml(
-        flit_pyproject_toml,
-    )
-
+    res = parse_pyproject_toml(flit_pyproject_toml)
     specs = {
-        dep.name: typing.cast(VersionedDependency, dep) for dep in res.dependencies
+        dep.dep.name: typing.cast(VersionedDependency, dep.dep)
+        for dep in res.dependencies
     }
 
     assert specs["requests"].version == ">=2.13.0"
@@ -602,12 +620,10 @@ def test_parse_flit(flit_pyproject_toml: Path):
 
 
 def test_parse_pdm(pdm_pyproject_toml: Path):
-    res = parse_pyproject_toml(
-        pdm_pyproject_toml,
-    )
-
+    res = parse_pyproject_toml(pdm_pyproject_toml)
     specs = {
-        dep.name: typing.cast(VersionedDependency, dep) for dep in res.dependencies
+        dep.dep.name: typing.cast(VersionedDependency, dep.dep)
+        for dep in res.dependencies
     }
 
     # Base dependencies
@@ -926,11 +942,11 @@ def test_run_lock_with_local_package(
 
     with virtual_package_repo:
         lock_spec = make_lock_spec(
-            src_files=[pip_local_package_environment],
+            src_file_paths=[pip_local_package_environment],
             virtual_package_repo=virtual_package_repo,
         )
     assert not any(
-        p.manager == "pip" for p in lock_spec.dependencies
+        p.manager == "pip" for p in chain(*lock_spec.dependencies.values())
     ), "conda-lock ignores editable pip deps"
@@ -991,18 +1007,19 @@ def test_poetry_version_parsing_constraints(
     with vpr, capsys.disabled():
         with tempfile.NamedTemporaryFile(dir=".") as tf:
             spec = LockSpecification(
-                dependencies=[
-                    VersionedDependency(
-                        name=package,
-                        version=poetry_version_to_conda_version(version) or "",
-                        manager="conda",
-                        optional=False,
-                        category="main",
-                        extras=[],
-                    )
-                ],
+                dependencies={
+                    "linux-64": [
+                        VersionedDependency(
+                            name=package,
+                            version=poetry_version_to_conda_version(version) or "",
+                            manager="conda",
+                            optional=False,
+                            category="main",
+                            extras=[],
+                        )
+                    ]
+                },
                 channels=[Channel.from_string("conda-forge")],
-                platforms=["linux-64"],
                 # NB: this file must exist for relative path resolution to work
                 # in create_lockfile_from_spec
                 sources=[Path(tf.name)],
@@ -1047,145 +1064,79 @@ def _make_spec(name: str, constraint: str = "*"):
     )
 
 
-def _make_dependency_with_platforms(
-    name: str, platforms: typing.List[str], constraint: str = "*"
-):
-    return VersionedDependency(
-        name=name,
-        version=constraint,
-        selectors=Selectors(platform=platforms),
-    )
-
-
-def test_aggregate_lock_specs():
-    """Ensure that the way two specs combine when both specify channels is correct"""
-    base_spec = LockSpecification(
-        dependencies=[_make_spec("python", "=3.7")],
-        channels=[Channel.from_string("conda-forge")],
-        platforms=["linux-64"],
-        sources=[Path("base-env.yml")],
-    )
-
-    gpu_spec = LockSpecification(
-        dependencies=[_make_spec("pytorch")],
-        channels=[Channel.from_string("pytorch"), Channel.from_string("conda-forge")],
-        platforms=["linux-64"],
-        sources=[Path("ml-stuff.yml")],
-    )
-
-    # NB: content hash explicitly does not depend on the source file names
-    actual = aggregate_lock_specs([base_spec, gpu_spec])
+def test_lock_spec(lock_spec_env: List[Path]):
+    """Ensure that the way two files combine when both specify channels is correct"""
+    actual = make_lock_spec(src_file_paths=lock_spec_env, virtual_package_repo=None)  # type: ignore
 
     expected = LockSpecification(
-        dependencies=[
-            _make_spec("python", "=3.7"),
-            _make_spec("pytorch"),
-        ],
+        dependencies={
+            "linux-64": [
+                _make_spec("python", "3.7.*"),
+                _make_spec("pytorch", constraint=""),
+            ]
+        },
         channels=[
             Channel.from_string("pytorch"),
             Channel.from_string("conda-forge"),
         ],
-        platforms=["linux-64"],
-        sources=[],
+        sources=lock_spec_env,
     )
     assert actual.dict(exclude={"sources"}) == expected.dict(exclude={"sources"})
+    # NB: content hash explicitly does not depend on the source file names
     assert actual.content_hash() == expected.content_hash()
 
 
-def test_aggregate_lock_specs_multiple_platforms():
+def test_aggregate_lock_specs_multiple_platforms(lock_spec_env_platforms: List[Path]):
     """Ensure that plaforms are merged correctly"""
-    linux_spec = LockSpecification(
-        dependencies=[_make_dependency_with_platforms("python", ["linux-64"], "=3.7")],
-        channels=[Channel.from_string("conda-forge")],
-        platforms=["linux-64"],
-        sources=[Path("base-env.yml")],
-    )
-
-    osx_spec = LockSpecification(
-        dependencies=[_make_dependency_with_platforms("python", ["osx-64"], "=3.7")],
-        channels=[Channel.from_string("conda-forge")],
-        platforms=["osx-64"],
-        sources=[Path("base-env.yml")],
-    )
-
-    # NB: content hash explicitly does not depend on the source file names
-    actual = aggregate_lock_specs([linux_spec, osx_spec])
+    actual = make_lock_spec(src_file_paths=lock_spec_env_platforms, virtual_package_repo=None)  # type: ignore
 
     expected = LockSpecification(
-        dependencies=[
-            _make_dependency_with_platforms("python", ["linux-64", "osx-64"], "=3.7")
-        ],
+        dependencies={
+            "linux-64": [_make_spec("python", "3.7.*")],
+            "osx-64": [_make_spec("python", "3.7.*")],
+        },
         channels=[
             Channel.from_string("conda-forge"),
         ],
-        platforms=["linux-64", "osx-64"],
-        sources=[],
+        sources=lock_spec_env_platforms,
    )
     assert actual.dict(exclude={"sources"}) == expected.dict(exclude={"sources"})
+    # NB: content hash explicitly does not depend on the source file names
     assert actual.content_hash() == expected.content_hash()
 
 
 def test_aggregate_lock_specs_override_version():
-    base_spec = LockSpecification(
-        dependencies=[_make_spec("package", "=1.0")],
-        channels=[Channel.from_string("conda-forge")],
-        platforms=["linux-64"],
-        sources=[Path("base.yml")],
-    )
+    base_deps: List[Dependency] = [_make_spec("package", "=1.0")]
+    override_deps: List[Dependency] = [_make_spec("package", "=2.0")]
+    agg_deps = aggregate_deps([base_deps, override_deps])
+    assert agg_deps == override_deps
 
-    override_spec = LockSpecification(
-        dependencies=[_make_spec("package", "=2.0")],
-        channels=[Channel.from_string("internal"), Channel.from_string("conda-forge")],
-        platforms=["linux-64"],
-        sources=[Path("override.yml")],
-    )
-    agg_spec = aggregate_lock_specs([base_spec, override_spec])
-
-    assert agg_spec.dependencies == override_spec.dependencies
-
-
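The rewritten override test above exercises the merge rule directly on `aggregate_deps`: when the same package name appears in several input lists, later lists win. In spirit (illustrative sketch; it assumes `VersionedDependency`'s defaults cover manager/category the way `_make_spec` relies on):

from typing import List

from conda_lock.src_parser import Dependency, VersionedDependency, aggregate_deps

base: List[Dependency] = [VersionedDependency(name="package", version="=1.0")]
override: List[Dependency] = [VersionedDependency(name="package", version="=2.0")]

# Later entries take precedence, so a downstream file can pin a new version.
assert aggregate_deps([base, override]) == override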
-def test_aggregate_lock_specs_invalid_channels():
+def test_aggregate_channels():
     """Ensure that aggregating specs from mismatched channel orderings raises an error."""
-    base_spec = LockSpecification(
-        dependencies=[],
-        channels=[Channel.from_string("defaults")],
-        platforms=[],
-        sources=[],
-    )
+    base = [Channel.from_string("defaults")]
 
-    add_conda_forge = base_spec.copy(
-        update={
-            "channels": [
-                Channel.from_string("conda-forge"),
-                Channel.from_string("defaults"),
-            ]
-        }
-    )
-    agg_spec = aggregate_lock_specs([base_spec, add_conda_forge])
-    assert agg_spec.channels == add_conda_forge.channels
+    add_conda_forge = [
+        Channel.from_string("conda-forge"),
+        Channel.from_string("defaults"),
+    ]
+    agg_spec = aggregate_channels([base, add_conda_forge])
+    assert agg_spec == add_conda_forge
 
     # swap the order of the two channels, which is an error
-    flipped = base_spec.copy(
-        update={
-            "channels": [
-                Channel.from_string("defaults"),
-                Channel.from_string("conda-forge"),
-            ]
-        }
-    )
+    flipped = [
+        Channel.from_string("defaults"),
+        Channel.from_string("conda-forge"),
+    ]
 
     with pytest.raises(ChannelAggregationError):
-        agg_spec = aggregate_lock_specs([base_spec, add_conda_forge, flipped])
+        agg_spec = aggregate_channels([base, add_conda_forge, flipped])
+
+    add_pytorch = [
+        Channel.from_string("pytorch"),
+        Channel.from_string("defaults"),
+    ]
 
-    add_pytorch = base_spec.copy(
-        update={
-            "channels": [
-                Channel.from_string("pytorch"),
-                Channel.from_string("defaults"),
-            ]
-        }
-    )
     with pytest.raises(ChannelAggregationError):
-        agg_spec = aggregate_lock_specs([base_spec, add_conda_forge, add_pytorch])
+        agg_spec = aggregate_channels([base, add_conda_forge, add_pytorch])
@@ -1523,9 +1474,8 @@ def test_virtual_package_input_hash_stability():
     vpr = virtual_package_repo_from_specification(vspec)
 
     spec = LockSpecification(
-        dependencies=[],
+        dependencies={"linux-64": []},
         channels=[],
-        platforms=["linux-64"],
         sources=[],
         virtual_package_repo=vpr,
     )
@@ -1548,9 +1498,8 @@ def test_default_virtual_package_input_hash_stability():
     }
 
     spec = LockSpecification(
-        dependencies=[],
+        dependencies={key: [] for key in expected.keys()},
         channels=[],
-        platforms=list(expected.keys()),
         sources=[],
         virtual_package_repo=vpr,
     )
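Taken together, the hash-stability hunks above show the shape of the refactored model: `LockSpecification.dependencies` is now a mapping from platform to its dependency list, and the separate `platforms` field is gone because the keys already carry that information. A minimal construction under that shape (illustrative; it leans on the model defaults these tests use):

from pathlib import Path

from conda_lock.models.channel import Channel
from conda_lock.src_parser import LockSpecification, VersionedDependency

spec = LockSpecification(
    dependencies={
        "linux-64": [VersionedDependency(name="python", version=">=3.8")],
        "osx-64": [VersionedDependency(name="python", version=">=3.8")],
    },
    channels=[Channel.from_string("conda-forge")],
    sources=[Path("environment.yml")],
)
print(sorted(spec.dependencies))  # -> ['linux-64', 'osx-64']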