From d4e91efc7da2ac4a597072f1f56b3f69886826c7 Mon Sep 17 00:00:00 2001 From: Tarun Pratap Singh <101409098+Wackyator@users.noreply.github.com> Date: Mon, 2 Oct 2023 18:43:23 +0530 Subject: [PATCH] feat: add solver (#361) This PR achieves functionality described in #353 TODO: - [x] Test --- py-rattler/Cargo.lock | 84 +++++++++++++++++++++++++++ py-rattler/Cargo.toml | 2 +- py-rattler/pixi.lock | 21 +++++++ py-rattler/pixi.toml | 4 ++ py-rattler/rattler/__init__.py | 4 ++ py-rattler/rattler/channel/channel.py | 4 +- py-rattler/rattler/solver/__init__.py | 3 + py-rattler/rattler/solver/solver.py | 58 ++++++++++++++++++ py-rattler/src/error.rs | 11 ++++ py-rattler/src/lib.rs | 5 ++ py-rattler/src/match_spec.rs | 2 +- py-rattler/src/repo_data/sparse.rs | 12 ++-- py-rattler/src/solver.rs | 49 ++++++++++++++++ py-rattler/tests/unit/test_solver.py | 32 ++++++++++ 14 files changed, 284 insertions(+), 7 deletions(-) create mode 100644 py-rattler/rattler/solver/__init__.py create mode 100644 py-rattler/rattler/solver/solver.py create mode 100644 py-rattler/src/solver.rs create mode 100644 py-rattler/tests/unit/test_solver.py diff --git a/py-rattler/Cargo.lock b/py-rattler/Cargo.lock index 8db4d65bc..ec182bf7d 100644 --- a/py-rattler/Cargo.lock +++ b/py-rattler/Cargo.lock @@ -248,6 +248,18 @@ version = "2.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b4682ae6287fcf752ecaabbfcc7b6f9b72aa33933dc23a554d853aea8eea8635" +[[package]] +name = "bitvec" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1bc2832c24239b0141d5674bb9174f9d68a8b5b3f2753311927c172ca46f7e9c" +dependencies = [ + "funty", + "radium", + "tap", + "wyz", +] + [[package]] name = "blake2" version = "0.10.6" @@ -558,6 +570,15 @@ version = "1.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a26ae43d7bcc3b814de94796a5e736d4029efb0ee900c12e2d54c993ad1a1e07" +[[package]] +name = "elsa" +version = "1.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "714f766f3556b44e7e4776ad133fcc3445a489517c25c704ace411bb14790194" +dependencies = [ + "stable_deref_trait", +] + [[package]] name = "encoding_rs" version = "0.8.33" @@ -660,6 +681,12 @@ dependencies = [ "windows-sys", ] +[[package]] +name = "fixedbitset" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0ce7134b9999ecaf8bcd65542e436736ef32ddca1b3e06094cb6ec5755203b80" + [[package]] name = "flate2" version = "1.0.27" @@ -700,6 +727,12 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "funty" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6d5a32815ae3f33302d95fdcb2ce17862f8c65363dcfd29360480ba1001fc9c" + [[package]] name = "futures" version = "0.1.31" @@ -1627,6 +1660,16 @@ version = "2.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9b2a4787296e9989611394c33f193f676704af1686e70b8f8033ab5ba9a35a94" +[[package]] +name = "petgraph" +version = "0.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e1d3afd2628e69da2be385eb6f2fd57c8ac7977ceeff6dc166ff1657b0e386a9" +dependencies = [ + "fixedbitset", + "indexmap 2.0.0", +] + [[package]] name = "pin-project-lite" version = "0.2.13" @@ -1836,6 +1879,12 @@ dependencies = [ "proc-macro2", ] +[[package]] +name = "radium" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc33ff2d4973d518d823d61aa239014831e521c75da58e3df4840d3f47749d09" + [[package]] name = "rand" version = "0.8.5" @@ -2071,6 +2120,7 @@ dependencies = [ "rattler_conda_types", "rattler_digest", "rattler_libsolv_c", + "resolvo", "serde", "tempfile", "thiserror", @@ -2192,6 +2242,19 @@ dependencies = [ "winreg", ] +[[package]] +name = "resolvo" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6dab30801b54723f1949c6453a35db09c89e2ce7e052dc63e715f32fb40e427c" +dependencies = [ + "bitvec", + "elsa", + "itertools", + "petgraph", + "tracing", +] + [[package]] name = "retry-policies" version = "0.2.0" @@ -2502,6 +2565,12 @@ dependencies = [ "windows-sys", ] +[[package]] +name = "stable_deref_trait" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3" + [[package]] name = "static_assertions" version = "1.1.0" @@ -2570,6 +2639,12 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "tap" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369" + [[package]] name = "tar" version = "0.4.40" @@ -3116,6 +3191,15 @@ dependencies = [ "windows-sys", ] +[[package]] +name = "wyz" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05f360fc0b24296329c78fda852a1e9ae82de9cf7b27dae4b7f62f118f77b9ed" +dependencies = [ + "tap", +] + [[package]] name = "xattr" version = "1.0.1" diff --git a/py-rattler/Cargo.toml b/py-rattler/Cargo.toml index 98c3576c2..0de38e5a3 100644 --- a/py-rattler/Cargo.toml +++ b/py-rattler/Cargo.toml @@ -20,7 +20,7 @@ rattler_conda_types = { path = "../crates/rattler_conda_types", default-features rattler_networking = { path = "../crates/rattler_networking", default-features = false } rattler_shell = { path = "../crates/rattler_shell", default-features = false } rattler_virtual_packages = { path = "../crates/rattler_virtual_packages" } -rattler_solve = { path = "../crates/rattler_solve" } +rattler_solve = { path = "../crates/rattler_solve", features = ["resolvo"] } pyo3 = { version = "0.19", features = [ "abi3-py38", diff --git a/py-rattler/pixi.lock b/py-rattler/pixi.lock index 7e25dd991..bbfa9c170 100644 --- a/py-rattler/pixi.lock +++ b/py-rattler/pixi.lock @@ -1779,6 +1779,27 @@ package: noarch: python size: 46098 timestamp: 1681337144376 +- name: patchelf + version: 0.17.2 + manager: conda + platform: linux-64 + dependencies: + libgcc-ng: '>=7.5.0' + libstdcxx-ng: '>=7.5.0' + url: https://conda.anaconda.org/conda-forge/linux-64/patchelf-0.17.2-h58526e2_0.conda + hash: + md5: ba76a6a448819560b5f8b08a9c74f415 + sha256: eb355ac225be2f698e19dba4dcab7cb0748225677a9799e9cc8e4cadc3cb738f + optional: false + category: main + build: h58526e2_0 + arch: x86_64 + subdir: linux-64 + build_number: 0 + license: GPL-3.0-or-later + license_family: GPL + size: 94048 + timestamp: 1673473024463 - name: pathspec version: 0.11.2 manager: conda diff --git a/py-rattler/pixi.toml b/py-rattler/pixi.toml index dafb3165b..a8d10b66c 100644 --- a/py-rattler/pixi.toml +++ b/py-rattler/pixi.toml @@ -12,6 +12,7 @@ license = "BSD-3-Clause" [tasks] build = "PIP_REQUIRE_VIRTUALENV=false maturin develop" +build-release = "PIP_REQUIRE_VIRTUALENV=false maturin develop --release" test = { cmd = "pytest --doctest-modules", depends_on = ["build"] } fmt-python = "black ." fmt-rust = "cargo fmt --all" @@ -34,3 +35,6 @@ black = "~=23.7.0" ruff = "~=0.0.285" mypy = "~=1.5.1" pytest-asyncio = "0.21.1.*" + +[target.linux-64.dependencies] +patchelf = "~=0.17.2" diff --git a/py-rattler/rattler/__init__.py b/py-rattler/rattler/__init__.py index 4104d03cc..5076cf9d8 100644 --- a/py-rattler/rattler/__init__.py +++ b/py-rattler/rattler/__init__.py @@ -12,6 +12,8 @@ from rattler.virtual_package import GenericVirtualPackage, VirtualPackage from rattler.package import PackageName from rattler.prefix import PrefixRecord, PrefixPaths +from rattler.solver import solve +from rattler.platform import Platform __all__ = [ "Version", @@ -32,4 +34,6 @@ "PrefixRecord", "PrefixPaths", "SparseRepoData", + "solve", + "Platform", ] diff --git a/py-rattler/rattler/channel/channel.py b/py-rattler/rattler/channel/channel.py index 29bd5b533..e22eebc30 100644 --- a/py-rattler/rattler/channel/channel.py +++ b/py-rattler/rattler/channel/channel.py @@ -6,7 +6,9 @@ class Channel: - def __init__(self, name: str, channel_configuration: ChannelConfig) -> None: + def __init__( + self, name: str, channel_configuration: ChannelConfig = ChannelConfig() + ) -> None: """ Create a new channel. diff --git a/py-rattler/rattler/solver/__init__.py b/py-rattler/rattler/solver/__init__.py new file mode 100644 index 000000000..084e30ab0 --- /dev/null +++ b/py-rattler/rattler/solver/__init__.py @@ -0,0 +1,3 @@ +from rattler.solver.solver import solve + +__all__ = ["solve"] diff --git a/py-rattler/rattler/solver/solver.py b/py-rattler/rattler/solver/solver.py new file mode 100644 index 000000000..d443c5b28 --- /dev/null +++ b/py-rattler/rattler/solver/solver.py @@ -0,0 +1,58 @@ +from __future__ import annotations +from typing import List, Optional +from rattler.match_spec.match_spec import MatchSpec + +from rattler.rattler import py_solve +from rattler.repo_data.record import RepoDataRecord +from rattler.repo_data.sparse import SparseRepoData +from rattler.virtual_package.generic import GenericVirtualPackage + + +def solve( + specs: List[MatchSpec], + available_packages: List[SparseRepoData], + locked_packages: Optional[List[RepoDataRecord]] = None, + pinned_packages: Optional[List[RepoDataRecord]] = None, + virtual_packages: Optional[List[GenericVirtualPackage]] = None, +) -> List[RepoDataRecord]: + """ + Resolve the dependencies and return the `RepoDataRecord`s + that should be present in the environment. + + Arguments: + specs: A list of matchspec to solve. + available_packages: A list of RepoData to use for solving the `specs`. + locked_packages: Records of packages that are previously selected. + If the solver encounters multiple variants of a single + package (identified by its name), it will sort the records + and select the best possible version. However, if there + exists a locked version it will prefer that variant instead. + This is useful to reduce the number of packages that are + updated when installing new packages. Usually you add the + currently installed packages or packages from a lock-file here. + pinned_packages: Records of packages that are previously selected and CANNOT + be changed. If the solver encounters multiple variants of + a single package (identified by its name), it will sort the + records and select the best possible version. However, if + there is a variant available in the `pinned_packages` field it + will always select that version no matter what even if that + means other packages have to be downgraded. + virtual_packages: A list of virtual packages considered active. + + Returns: + Resolved list of `RepoDataRecord`s. + """ + + return [ + RepoDataRecord._from_py_record(solved_package) + for solved_package in py_solve( + [spec._match_spec for spec in specs], + [package._sparse for package in available_packages], + [package._record for package in locked_packages or []], + [package._record for package in pinned_packages or []], + [ + v_package._generic_virtual_package + for v_package in virtual_packages or [] + ], + ) + ] diff --git a/py-rattler/src/error.rs b/py-rattler/src/error.rs index b806e7e22..efdc7cd91 100644 --- a/py-rattler/src/error.rs +++ b/py-rattler/src/error.rs @@ -8,6 +8,7 @@ use rattler_conda_types::{ }; use rattler_repodata_gateway::fetch::FetchRepoDataError; use rattler_shell::activation::ActivationError; +use rattler_solve::SolveError; use rattler_virtual_packages::DetectVirtualPackageError; use thiserror::Error; @@ -38,6 +39,10 @@ pub enum PyRattlerError { DetectVirtualPackageError(#[from] DetectVirtualPackageError), #[error(transparent)] IoError(#[from] io::Error), + #[error(transparent)] + SolverError(#[from] SolveError), + #[error("invalid 'SparseRepoData' object found")] + InvalidSparseDataError, } impl From for PyErr { @@ -69,6 +74,10 @@ impl From for PyErr { DetectVirtualPackageException::new_err(err.to_string()) } PyRattlerError::IoError(err) => IoException::new_err(err.to_string()), + PyRattlerError::SolverError(err) => SolverException::new_err(err.to_string()), + PyRattlerError::InvalidSparseDataError => InvalidSparseDataException::new_err( + PyRattlerError::InvalidSparseDataError.to_string(), + ), } } } @@ -85,3 +94,5 @@ create_exception!(exceptions, FetchRepoDataException, PyException); create_exception!(exceptions, CacheDirException, PyException); create_exception!(exceptions, DetectVirtualPackageException, PyException); create_exception!(exceptions, IoException, PyException); +create_exception!(exceptions, SolverException, PyException); +create_exception!(exceptions, InvalidSparseDataException, PyException); diff --git a/py-rattler/src/lib.rs b/py-rattler/src/lib.rs index e3429efc7..dd0264ffc 100644 --- a/py-rattler/src/lib.rs +++ b/py-rattler/src/lib.rs @@ -9,6 +9,7 @@ mod platform; mod prefix_record; mod repo_data; mod shell; +mod solver; mod version; mod virtual_package; @@ -34,6 +35,7 @@ use pyo3::prelude::*; use platform::{PyArch, PyPlatform}; use shell::{PyActivationResult, PyActivationVariables, PyActivator, PyShellEnum}; +use solver::py_solve; use virtual_package::PyVirtualPackage; #[pymodule] @@ -71,6 +73,9 @@ fn rattler(py: Python, m: &PyModule) -> PyResult<()> { m.add_class::().unwrap(); m.add_class::().unwrap(); + m.add_function(wrap_pyfunction!(py_solve, m).unwrap()) + .unwrap(); + // Exceptions m.add( "InvalidVersionError", diff --git a/py-rattler/src/match_spec.rs b/py-rattler/src/match_spec.rs index 7c2a3a241..3c52b7f18 100644 --- a/py-rattler/src/match_spec.rs +++ b/py-rattler/src/match_spec.rs @@ -11,7 +11,7 @@ use crate::{ #[repr(transparent)] #[derive(Clone)] pub struct PyMatchSpec { - inner: MatchSpec, + pub(crate) inner: MatchSpec, } impl From for PyMatchSpec { diff --git a/py-rattler/src/repo_data/sparse.rs b/py-rattler/src/repo_data/sparse.rs index e387219ec..b8e5be46f 100644 --- a/py-rattler/src/repo_data/sparse.rs +++ b/py-rattler/src/repo_data/sparse.rs @@ -23,6 +23,12 @@ impl From for PySparseRepoData { } } +impl<'a> From<&'a PySparseRepoData> for &'a SparseRepoData { + fn from(value: &'a PySparseRepoData) -> Self { + value.inner.as_ref() + } +} + #[pymethods] impl PySparseRepoData { #[new] @@ -57,11 +63,9 @@ impl PySparseRepoData { repo_data: Vec, package_names: Vec, ) -> PyResult>> { - let repo_data = repo_data.iter().map(|r| r.inner.as_ref()); - let package_names = package_names.into_iter().map(Into::into); - - // release gil to allow other threads to progress py.allow_threads(move || { + let repo_data = repo_data.iter().map(Into::into); + let package_names = package_names.into_iter().map(Into::into); Ok( SparseRepoData::load_records_recursive(repo_data, package_names, None)? .into_iter() diff --git a/py-rattler/src/solver.rs b/py-rattler/src/solver.rs new file mode 100644 index 000000000..380b4d213 --- /dev/null +++ b/py-rattler/src/solver.rs @@ -0,0 +1,49 @@ +use pyo3::{pyfunction, PyResult, Python}; +use rattler_repodata_gateway::sparse::SparseRepoData; +use rattler_solve::{resolvo::Solver, SolverImpl, SolverTask}; + +use crate::{ + error::PyRattlerError, + generic_virtual_package::PyGenericVirtualPackage, + match_spec::PyMatchSpec, + repo_data::{repo_data_record::PyRepoDataRecord, sparse::PySparseRepoData}, +}; + +#[pyfunction] +pub fn py_solve( + py: Python<'_>, + specs: Vec, + available_packages: Vec, + locked_packages: Vec, + pinned_packages: Vec, + virtual_packages: Vec, +) -> PyResult> { + py.allow_threads(move || { + let package_names = specs + .iter() + .filter_map(|match_spec| match_spec.inner.name.clone()); + + let available_packages = SparseRepoData::load_records_recursive( + available_packages.iter().map(Into::into), + package_names, + None, + )?; + + let task = SolverTask { + available_packages: &available_packages, + locked_packages: locked_packages.into_iter().map(Into::into).collect(), + pinned_packages: pinned_packages.into_iter().map(Into::into).collect(), + virtual_packages: virtual_packages.into_iter().map(Into::into).collect(), + specs: specs.into_iter().map(Into::into).collect(), + }; + + Ok(Solver + .solve(task) + .map(|res| { + res.into_iter() + .map(Into::into) + .collect::>() + }) + .map_err(PyRattlerError::from)?) + }) +} diff --git a/py-rattler/tests/unit/test_solver.py b/py-rattler/tests/unit/test_solver.py new file mode 100644 index 000000000..a9354ea95 --- /dev/null +++ b/py-rattler/tests/unit/test_solver.py @@ -0,0 +1,32 @@ +# type: ignore +import os.path + +import pytest +from rattler import ( + solve, + Channel, + MatchSpec, + RepoDataRecord, + SparseRepoData, +) + + +@pytest.mark.asyncio +async def test_solve(): + linux64_chan = Channel("conda-forge") + data_dir = os.path.join(os.path.dirname(__file__), "../../../test-data/") + linux64_path = os.path.join(data_dir, "channels/conda-forge/linux-64/repodata.json") + linux64_data = SparseRepoData( + channel=linux64_chan, + subdir="linux-64", + path=linux64_path, + ) + + solved_data = solve( + [MatchSpec("python"), MatchSpec("sqlite")], + [linux64_data], + ) + + assert isinstance(solved_data, list) + assert isinstance(solved_data[0], RepoDataRecord) + assert len(solved_data) == 19