Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add solver #361

Merged
merged 12 commits into from
Oct 2, 2023
84 changes: 84 additions & 0 deletions py-rattler/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion py-rattler/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ rattler_conda_types = { path = "../crates/rattler_conda_types", default-features
rattler_networking = { path = "../crates/rattler_networking", default-features = false }
rattler_shell = { path = "../crates/rattler_shell", default-features = false }
rattler_virtual_packages = { path = "../crates/rattler_virtual_packages" }
rattler_solve = { path = "../crates/rattler_solve" }
rattler_solve = { path = "../crates/rattler_solve", features = ["resolvo"] }

pyo3 = { version = "0.19", features = [
"abi3-py38",
Expand Down
21 changes: 21 additions & 0 deletions py-rattler/pixi.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 4 additions & 0 deletions py-rattler/pixi.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ license = "BSD-3-Clause"

[tasks]
build = "PIP_REQUIRE_VIRTUALENV=false maturin develop"
build-release = "PIP_REQUIRE_VIRTUALENV=false maturin develop --release"
test = { cmd = "pytest --doctest-modules", depends_on = ["build"] }
fmt-python = "black ."
fmt-rust = "cargo fmt --all"
Expand All @@ -34,3 +35,6 @@ black = "~=23.7.0"
ruff = "~=0.0.285"
mypy = "~=1.5.1"
pytest-asyncio = "0.21.1.*"

[target.linux-64.dependencies]
patchelf = "~=0.17.2"
4 changes: 4 additions & 0 deletions py-rattler/rattler/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
from rattler.virtual_package import GenericVirtualPackage, VirtualPackage
from rattler.package import PackageName
from rattler.prefix import PrefixRecord, PrefixPaths
from rattler.solver import solve
from rattler.platform import Platform

__all__ = [
"Version",
Expand All @@ -32,4 +34,6 @@
"PrefixRecord",
"PrefixPaths",
"SparseRepoData",
"solve",
"Platform",
]
3 changes: 3 additions & 0 deletions py-rattler/rattler/solver/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from rattler.solver.solver import solve

__all__ = ["solve"]
64 changes: 64 additions & 0 deletions py-rattler/rattler/solver/solver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
from __future__ import annotations
from typing import List, Optional
from rattler.match_spec.match_spec import MatchSpec

from rattler.rattler import py_solve
from rattler.repo_data.record import RepoDataRecord
from rattler.repo_data.sparse import SparseRepoData
from rattler.virtual_package.generic import GenericVirtualPackage


def solve(
specs: List[MatchSpec],
available_packages: List[SparseRepoData],
locked_packages: Optional[List[RepoDataRecord]] = None,
pinned_packages: Optional[List[RepoDataRecord]] = None,
virtual_packages: Optional[List[GenericVirtualPackage]] = None,
) -> List[RepoDataRecord]:
"""
Resolve the dependencies and return the `RepoDataRecord`s
that should be present in the environment.

Arguments:
specs: A list of matchspec to solve.
available_packages: A list of RepoData to use for solving the `specs`.
locked_packages: Records of packages that are previously selected.
If the solver encounters multiple variants of a single
package (identified by its name), it will sort the records
and select the best possible version. However, if there
exists a locked version it will prefer that variant instead.
This is useful to reduce the number of packages that are
updated when installing new packages. Usually you add the
currently installed packages or packages from a lock-file here.
pinned_packages: Records of packages that are previously selected and CANNOT
be changed. If the solver encounters multiple variants of
a single package (identified by its name), it will sort the
records and select the best possible version. However, if
there is a variant available in the `pinned_packages` field it
will always select that version no matter what even if that
means other packages have to be downgraded.
virtual_packages: A list of virtual packages considered active.

Returns:
Resolved list of `RepoDataRecord`s.
"""

if not locked_packages:
locked_packages = list()

if not pinned_packages:
pinned_packages = list()

if not virtual_packages:
virtual_packages = list()
tarunps marked this conversation as resolved.
Show resolved Hide resolved

return [
RepoDataRecord._from_py_record(solved_package)
for solved_package in py_solve(
[spec._match_spec for spec in specs],
[package._sparse for package in available_packages],
[package._record for package in locked_packages],
[package._record for package in pinned_packages],
[v_package._generic_virtual_package for v_package in virtual_packages],
)
]
11 changes: 11 additions & 0 deletions py-rattler/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use rattler_conda_types::{
};
use rattler_repodata_gateway::fetch::FetchRepoDataError;
use rattler_shell::activation::ActivationError;
use rattler_solve::SolveError;
use rattler_virtual_packages::DetectVirtualPackageError;
use thiserror::Error;

Expand Down Expand Up @@ -38,6 +39,10 @@ pub enum PyRattlerError {
DetectVirtualPackageError(#[from] DetectVirtualPackageError),
#[error(transparent)]
IoError(#[from] io::Error),
#[error(transparent)]
SolverError(#[from] SolveError),
#[error("invalid 'SparseRepoData' object found")]
InvalidSparseDataError,
}

impl From<PyRattlerError> for PyErr {
Expand Down Expand Up @@ -69,6 +74,10 @@ impl From<PyRattlerError> for PyErr {
DetectVirtualPackageException::new_err(err.to_string())
}
PyRattlerError::IoError(err) => IoException::new_err(err.to_string()),
PyRattlerError::SolverError(err) => SolverException::new_err(err.to_string()),
PyRattlerError::InvalidSparseDataError => InvalidSparseDataException::new_err(
PyRattlerError::InvalidSparseDataError.to_string(),
),
}
}
}
Expand All @@ -85,3 +94,5 @@ create_exception!(exceptions, FetchRepoDataException, PyException);
create_exception!(exceptions, CacheDirException, PyException);
create_exception!(exceptions, DetectVirtualPackageException, PyException);
create_exception!(exceptions, IoException, PyException);
create_exception!(exceptions, SolverException, PyException);
create_exception!(exceptions, InvalidSparseDataException, PyException);
5 changes: 5 additions & 0 deletions py-rattler/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ mod platform;
mod prefix_record;
mod repo_data;
mod shell;
mod solver;
mod version;
mod virtual_package;

Expand All @@ -34,6 +35,7 @@ use pyo3::prelude::*;

use platform::{PyArch, PyPlatform};
use shell::{PyActivationResult, PyActivationVariables, PyActivator, PyShellEnum};
use solver::py_solve;
use virtual_package::PyVirtualPackage;

#[pymodule]
Expand Down Expand Up @@ -71,6 +73,9 @@ fn rattler(py: Python, m: &PyModule) -> PyResult<()> {
m.add_class::<PyPrefixRecord>().unwrap();
m.add_class::<PyPrefixPaths>().unwrap();

m.add_function(wrap_pyfunction!(py_solve, m).unwrap())
.unwrap();

// Exceptions
m.add(
"InvalidVersionError",
Expand Down
2 changes: 1 addition & 1 deletion py-rattler/src/match_spec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use crate::{
#[repr(transparent)]
#[derive(Clone)]
pub struct PyMatchSpec {
inner: MatchSpec,
pub(crate) inner: MatchSpec,
}

impl From<MatchSpec> for PyMatchSpec {
Expand Down
12 changes: 8 additions & 4 deletions py-rattler/src/repo_data/sparse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,12 @@ impl From<SparseRepoData> for PySparseRepoData {
}
}

impl<'a> From<&'a PySparseRepoData> for &'a SparseRepoData {
fn from(value: &'a PySparseRepoData) -> Self {
value.inner.as_ref()
}
}

#[pymethods]
impl PySparseRepoData {
#[new]
Expand Down Expand Up @@ -57,11 +63,9 @@ impl PySparseRepoData {
repo_data: Vec<PySparseRepoData>,
package_names: Vec<PyPackageName>,
) -> PyResult<Vec<Vec<PyRepoDataRecord>>> {
let repo_data = repo_data.iter().map(|r| r.inner.as_ref());
let package_names = package_names.into_iter().map(Into::into);

// release gil to allow other threads to progress
py.allow_threads(move || {
let repo_data = repo_data.iter().map(Into::into);
let package_names = package_names.into_iter().map(Into::into);
Ok(
SparseRepoData::load_records_recursive(repo_data, package_names, None)?
.into_iter()
Expand Down
Loading