Skip to content

Commit

Permalink
slight refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
JohanMabille committed Oct 18, 2023
1 parent 2522ebf commit 46786c5
Showing 1 changed file with 170 additions and 98 deletions.
268 changes: 170 additions & 98 deletions libmamba/src/core/transaction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -828,54 +828,172 @@ namespace mamba
add_json(to_unlink, "UNLINK");
}

bool MTransaction::fetch_extract_packages()
namespace
{
// TODO: move this to the PackageDownloadMonitor
auto& pbar_manager = Console::instance().init_progress_bar_manager(ProgressBarMode::aggregated
);
using FetcherList = std::vector<PackageFetcher>;
// Free functions instead of private method to avoid exposing downloaders
// and package fetchers in the header. Ideally we may want a pimpl or
// a private implementation header when we refactor this class.
FetcherList
build_fetchers(MPool& pool, const Solution& solution, MultiPackageCache& multi_cache)
{
FetcherList fetchers;
auto& channel_context = pool.channel_context();
auto& ctx = channel_context.context();

auto& channel_context = m_pool.channel_context();
auto& ctx = channel_context.context();
PackageFetcherSemaphore::set_max(ctx.threads_params.extract_threads);
if (ctx.experimental && ctx.validation_params.verify_artifacts)
{
LOG_INFO << "Content trust is enabled, package(s) signatures will be verified";
}
for_each_to_install(
solution.actions,
[&](const auto& pkg)
{
if (ctx.experimental && ctx.validation_params.verify_artifacts)
{
const auto& repo_checker = channel_context.make_channel(pkg.channel)
.repo_checker(ctx, multi_cache);
repo_checker.verify_package(
pkg.json_signable(),
nlohmann::json::parse(pkg.signatures)
);

LOG_DEBUG << "'" << pkg.name << "' trusted from '" << pkg.channel << "'";
}
fetchers.emplace_back(pkg, channel_context, multi_cache);
}
);

if (ctx.experimental && ctx.validation_params.verify_artifacts)
{
auto out = Console::stream();
fmt::print(
out,
"Content trust verifications successful, {} ",
fmt::styled("package(s) are trusted", ctx.graphics_params.palette.safe)
);
LOG_INFO << "All package(s) are trusted";
}
return fetchers;
}

using ExtractRequestList = std::vector<PackageExtractRequest>;

// TODO: move to private method create_fetchers
if (ctx.experimental && ctx.validation_params.verify_artifacts)
ExtractRequestList
build_extract_requests(const Context& context, FetcherList& fetchers, std::size_t extract_size)
{
LOG_INFO << "Content trust is enabled, package(s) signatures will be verified";
auto extract_options = ExtractOptions::from_context(context);
ExtractRequestList extract_requests;
extract_requests.reserve(extract_size);
std::transform(
fetchers.begin(),
fetchers.begin() + static_cast<std::ptrdiff_t>(extract_size),
std::back_inserter(extract_requests),
[extract_options](auto& f) { return f.build_extract_request(extract_options); }
);
return extract_requests;
}

std::vector<PackageFetcher> fetchers;
using ExtractTaskList = std::vector<std::future<void>>;

for_each_to_install(
m_solution.actions,
[&](const auto& pkg)
MultiDownloadRequest build_download_requests(
FetcherList& fetchers,
ExtractRequestList& extract_requests,
ExtractTaskList& extract_tasks,
std::size_t download_size
)
{
MultiDownloadRequest download_requests;
download_requests.reserve(download_size);
for (auto [fit, eit] = std::tuple{ fetchers.begin(), extract_requests.begin() };
fit != fetchers.begin() + static_cast<std::ptrdiff_t>(download_size);
++fit, ++eit)
{
if (ctx.experimental && ctx.validation_params.verify_artifacts)
{
const auto& repo_checker = channel_context.make_channel(pkg.channel)
.repo_checker(ctx, m_multi_cache);
repo_checker.verify_package(
pkg.json_signable(),
nlohmann::json::parse(pkg.signatures)
);
auto ceit = eit; // Apple Clang cannot capture eit
auto task = std::make_shared<std::packaged_task<void(std::size_t)>>(
[ceit](std::size_t downloaded_size) { ceit->run(downloaded_size); }
);
extract_tasks.push_back(task->get_future());
download_requests.push_back(fit->build_download_request(
[extract_task = std::move(task)](std::size_t downloaded_size)
{
MainExecutor::instance().schedule(
[t = std::move(extract_task)](std::size_t ds) { (*t)(ds); },
downloaded_size
);
}
));
}
return download_requests;
}

LOG_DEBUG << "'" << pkg.name << "' trusted from '" << pkg.channel << "'";
}
fetchers.emplace_back(pkg, m_pool.channel_context(), m_multi_cache);
void schedule_extractions(
ExtractRequestList& extract_requests,
ExtractTaskList& extract_tasks,
std::size_t download_size
)
{
for (auto it = extract_requests.begin() + static_cast<std::ptrdiff_t>(download_size);
it != extract_requests.end();
++it)
{
std::packaged_task task{ [=] { it->run(); } };
extract_tasks.push_back(task.get_future());
MainExecutor::instance().schedule(std::move(task));
}
);
}

if (ctx.experimental && ctx.validation_params.verify_artifacts)
bool trigger_download(
MultiDownloadRequest requests,
const Context& context,
DownloadOptions options,
PackageDownloadMonitor* monitor
)
{
auto out = Console::stream();
fmt::print(
out,
"Content trust verifications successful, {} ",
fmt::styled("package(s) are trusted", ctx.graphics_params.palette.safe)
auto result = download(std::move(requests), context, options, monitor);
bool all_downloaded = std::accumulate(
result.begin(),
result.end(),
true,
[](bool acc, const auto& r) { return acc && r; }
);
LOG_INFO << "All package(s) are trusted";
return all_downloaded;
}

bool check_all_valid(const FetcherList& fetchers, const ExtractRequestList& extract_requests)
{
bool all_valid = true;
for (auto [fit, eit] = std::tuple{ fetchers.begin(), extract_requests.begin() };
eit != extract_requests.end();
++fit, ++eit)
{
PackageExtractRequest::Result res = eit->get_result();
if (!res.valid || !res.extracted)
{
fit->clear_cache();
all_valid = false;
// TODO: check if we can remove this
throw std::runtime_error(
std::string("Found incorrect download: ") + fit->name() + ". Aborting"
);
}
}
return all_valid;
}
}

bool MTransaction::fetch_extract_packages()
{
// TODO: move this to the PackageDownloadMonitor
auto& pbar_manager = Console::instance().init_progress_bar_manager(ProgressBarMode::aggregated
);

auto& channel_context = m_pool.channel_context();
auto& ctx = channel_context.context();
PackageFetcherSemaphore::set_max(ctx.threads_params.extract_threads);

FetcherList fetchers = build_fetchers(m_pool, m_solution, m_multi_cache);

auto download_end = std::partition(
fetchers.begin(),
fetchers.end(),
Expand All @@ -886,46 +1004,25 @@ namespace mamba
fetchers.end(),
[](const auto& f) { return f.needs_extract(); }
);

auto download_size = static_cast<std::size_t>(std::distance(fetchers.begin(), download_end));
auto extract_size = static_cast<std::size_t>(std::distance(fetchers.begin(), extract_end));

// At this point:
// - [fetchers.begin(), download_end) contains packages that need to be downloaded,
// validated and extracted
// - [download_end, extract_end) contains packages that need to be extracted only
// - [extract_end, fecthers.end()) contains packages already installed and extracted

auto extract_options = ExtractOptions::from_context(ctx);
std::vector<PackageExtractRequest> extract_requests;
extract_requests.reserve(std::distance(fetchers.begin(), extract_end));
std::transform(
fetchers.begin(),
extract_end,
std::back_inserter(extract_requests),
[extract_options](auto& f) { return f.build_extract_request(extract_options); }
);
// Tracks extraction requests
std::vector<std::future<void>> extract_tasks;
ExtractRequestList extract_requests = build_extract_requests(ctx, fetchers, extract_size);
ExtractTaskList extract_tasks;
extract_tasks.reserve(extract_requests.size());

std::vector<DownloadRequest> download_requests;
download_requests.reserve(std::distance(fetchers.begin(), download_end));
for (auto [fit, eit] = std::tuple{ fetchers.begin(), extract_requests.begin() };
fit != download_end;
++fit, ++eit)
{
auto ceit = eit; // Apple Clang cannot capture eit
auto task = std::make_shared<std::packaged_task<void(std::size_t)>>(
[ceit](std::size_t downloaded_size) { ceit->run(downloaded_size); }
);
extract_tasks.push_back(task->get_future());
download_requests.push_back(fit->build_download_request(
[extract_task = std::move(task)](std::size_t downloaded_size)
{
MainExecutor::instance().schedule(
[t = std::move(extract_task)](std::size_t ds) { (*t)(ds); },
downloaded_size
);
}
));
}
MultiDownloadRequest download_requests = build_download_requests(
fetchers,
extract_requests,
extract_tasks,
download_size
);

std::unique_ptr<PackageDownloadMonitor> monitor = nullptr;
DownloadOptions download_options{ true, true };
Expand All @@ -935,23 +1032,13 @@ namespace mamba
monitor->observe(download_requests, extract_requests, download_options);
}

for (auto it = extract_requests.begin() + download_requests.size();
it != extract_requests.end();
++it)
{
std::packaged_task task{ [=] { it->run(); } };
extract_tasks.push_back(task.get_future());
MainExecutor::instance().schedule(std::move(task));
}

auto result = download(std::move(download_requests), ctx, download_options, monitor.get());
bool all_downloaded = std::accumulate(
result.begin(),
result.end(),
true,
[](bool acc, const auto& r) { return acc && r; }
schedule_extractions(extract_requests, extract_tasks, download_size);
bool all_downloaded = trigger_download(
std::move(download_requests),
ctx,
download_options,
monitor.get()
);

if (!all_downloaded)
{
LOG_ERROR << "Download didn't finish!";
Expand All @@ -964,22 +1051,7 @@ namespace mamba
task.wait();
}

bool all_valid = true;
for (auto [fit, eit] = std::tuple{ fetchers.begin(), extract_requests.begin() };
eit != extract_requests.end();
++fit, ++eit)
{
PackageExtractRequest::Result res = eit->get_result();
if (!res.valid || !res.extracted)
{
fit->clear_cache();
all_valid = false;
// TODO: check if we can remove this
throw std::runtime_error(
std::string("Found incorrect download: ") + fit->name() + ". Aborting"
);
}
}
bool all_valid = check_all_valid(fetchers, extract_requests);
return !is_sig_interrupted() && all_valid;
}

Expand Down

0 comments on commit 46786c5

Please sign in to comment.