Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Generalize prediction cache. #8783

Merged
merged 5 commits into from
Feb 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
134 changes: 134 additions & 0 deletions include/xgboost/cache.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
/**
* Copyright 2023 by XGBoost contributors
*/
#ifndef XGBOOST_CACHE_H_
#define XGBOOST_CACHE_H_

#include <xgboost/logging.h> // CHECK_EQ

#include <cstddef> // std::size_t
#include <memory> // std::weak_ptr,std::shared_ptr,std::make_shared
#include <queue> // std:queue
#include <unordered_map> // std::unordered_map
#include <vector> // std::vector

namespace xgboost {
class DMatrix;
/**
* \brief FIFO cache for DMatrix related data.
*
* \tparam CacheT The type that needs to be cached.
*/
template <typename CacheT>
class DMatrixCache {
public:
struct Item {
// A weak pointer for checking whether the DMatrix object has expired.
std::weak_ptr<DMatrix> ref;
// The cached item
std::shared_ptr<CacheT> value;

CacheT const& Value() const { return *value; }
CacheT& Value() { return *value; }
};

protected:
std::unordered_map<DMatrix const*, Item> container_;
std::queue<DMatrix const*> queue_;
std::size_t max_size_;

void CheckConsistent() const { CHECK_EQ(queue_.size(), container_.size()); }

void ClearExpired() {
// Clear expired entries
this->CheckConsistent();
std::vector<DMatrix const*> expired;
std::queue<DMatrix const*> remained;

while (!queue_.empty()) {
auto p_fmat = queue_.front();
auto it = container_.find(p_fmat);
CHECK(it != container_.cend());
if (it->second.ref.expired()) {
expired.push_back(it->first);
} else {
remained.push(it->first);
}
queue_.pop();
}
CHECK(queue_.empty());
CHECK_EQ(remained.size() + expired.size(), container_.size());

for (auto const* p_fmat : expired) {
container_.erase(p_fmat);
}
while (!remained.empty()) {
auto p_fmat = remained.front();
queue_.push(p_fmat);
remained.pop();
}
this->CheckConsistent();
}

void ClearExcess() {
this->CheckConsistent();
while (queue_.size() >= max_size_) {
auto p_fmat = queue_.front();
queue_.pop();
container_.erase(p_fmat);
}
this->CheckConsistent();
}

public:
/**
* \param cache_size Maximum size of the cache.
*/
explicit DMatrixCache(std::size_t cache_size) : max_size_{cache_size} {}
/**
* \brief Cache a new DMatrix if it's no in the cache already.
*
* Passing in a `shared_ptr` is critical here. First to create a `weak_ptr` inside the
* entry this shared pointer is necessary. More importantly, the life time of this
* cache is tied to the shared pointer.
*
* \param m shared pointer to the DMatrix that needs to be cached.
* \param args The arguments for constructing a new cache item, if needed.
*
* \return The cache entry for passed in DMatrix, either an existing cache or newly
* created.
*/
template <typename... Args>
std::shared_ptr<CacheT>& CacheItem(std::shared_ptr<DMatrix> m, Args const&... args) {
CHECK(m);
this->ClearExpired();
if (container_.size() >= max_size_) {
this->ClearExcess();
}
// after clear, cache size < max_size
CHECK_LT(container_.size(), max_size_);
auto it = container_.find(m.get());
if (it == container_.cend()) {
// after the new DMatrix, cache size is at most max_size
container_[m.get()] = {m, std::make_shared<CacheT>(args...)};
queue_.push(m.get());
}
return container_.at(m.get()).value;
}
/**
* \brief Get a const reference to the underlying hash map. Clear expired caches before
* returning.
*/
decltype(container_) const& Container() {
this->ClearExpired();
return container_;
}

std::shared_ptr<CacheT> Entry(DMatrix const* m) const {
CHECK(container_.find(m) != container_.cend());
CHECK(!container_.at(m).ref.expired());
return container_.at(m).value;
}
};
} // namespace xgboost
#endif // XGBOOST_CACHE_H_
5 changes: 2 additions & 3 deletions include/xgboost/gbm.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*!
* Copyright 2014-2022 by XGBoost Contributors
/**
* Copyright 2014-2023 by XGBoost Contributors
* \file gbm.h
* \brief Interface of gradient booster,
* that learns through gradient statistics.
Expand Down Expand Up @@ -31,7 +31,6 @@ class ObjFunction;
struct Context;
struct LearnerModelParam;
struct PredictionCacheEntry;
class PredictionContainer;

/*!
* \brief interface of gradient boosting model.
Expand Down
28 changes: 5 additions & 23 deletions include/xgboost/learner.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*!
* Copyright 2015-2022 by XGBoost Contributors
/**
* Copyright 2015-2023 by XGBoost Contributors
* \file learner.h
* \brief Learner interface that integrates objective, gbm and evaluation together.
* This is the user facing XGBoost training module.
Expand All @@ -8,12 +8,13 @@
#ifndef XGBOOST_LEARNER_H_
#define XGBOOST_LEARNER_H_

#include <dmlc/io.h> // Serializable
#include <xgboost/base.h>
#include <xgboost/context.h> // Context
#include <xgboost/feature_map.h>
#include <xgboost/host_device_vector.h>
#include <xgboost/linalg.h> // Tensor
#include <xgboost/model.h>
#include <xgboost/predictor.h>
#include <xgboost/task.h>

#include <map>
Expand All @@ -29,6 +30,7 @@ class GradientBooster;
class ObjFunction;
class DMatrix;
class Json;
struct XGBAPIThreadLocalEntry;

enum class PredictionType : std::uint8_t { // NOLINT
kValue = 0,
Expand All @@ -40,26 +42,6 @@ enum class PredictionType : std::uint8_t { // NOLINT
kLeaf = 6
};

/*! \brief entry to to easily hold returning information */
struct XGBAPIThreadLocalEntry {
/*! \brief result holder for returning string */
std::string ret_str;
/*! \brief result holder for returning raw buffer */
std::vector<char> ret_char_vec;
/*! \brief result holder for returning strings */
std::vector<std::string> ret_vec_str;
/*! \brief result holder for returning string pointers */
std::vector<const char *> ret_vec_charp;
/*! \brief returning float vector. */
std::vector<bst_float> ret_vec_float;
/*! \brief temp variable of gradient pairs. */
std::vector<GradientPair> tmp_gpair;
/*! \brief Temp variable for returning prediction result. */
PredictionCacheEntry prediction_entry;
/*! \brief Temp variable for returning prediction shape. */
std::vector<bst_ulong> prediction_shape;
};

/*!
* \brief Learner class that does training and prediction.
* This is the user facing module of xgboost training.
Expand Down
77 changes: 24 additions & 53 deletions include/xgboost/predictor.h
Original file line number Diff line number Diff line change
@@ -1,95 +1,66 @@
/*!
* Copyright 2017-2022 by Contributors
/**
* Copyright 2017-2023 by Contributors
* \file predictor.h
* \brief Interface of predictor,
* performs predictions for a gradient booster.
*/
#pragma once
#include <xgboost/base.h>
#include <xgboost/cache.h> // DMatrixCache
#include <xgboost/context.h>
#include <xgboost/data.h>
#include <xgboost/host_device_vector.h>

#include <functional>
#include <functional> // std::function
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

// Forward declarations
namespace xgboost {
class TreeUpdater;
namespace gbm {
struct GBTreeModel;
} // namespace gbm
}
} // namespace xgboost

namespace xgboost {
/**
* \struct PredictionCacheEntry
*
* \brief Contains pointer to input matrix and associated cached predictions.
*/
struct PredictionCacheEntry {
// A storage for caching prediction values
HostDeviceVector<bst_float> predictions;
// The version of current cache, corresponding number of layers of trees
uint32_t version { 0 };
// A weak pointer for checking whether the DMatrix object has expired.
std::weak_ptr< DMatrix > ref;
std::uint32_t version{0};

PredictionCacheEntry() = default;
/* \brief Update the cache entry by number of versions.
/**
* \brief Update the cache entry by number of versions.
*
* \param v Added versions.
*/
void Update(uint32_t v) {
void Update(std::uint32_t v) {
version += v;
}
};

/* \brief A container for managed prediction caches.
/**
* \brief A container for managed prediction caches.
*/
class PredictionContainer {
std::unordered_map<DMatrix *, PredictionCacheEntry> container_;
void ClearExpiredEntries();
class PredictionContainer : public DMatrixCache<PredictionCacheEntry> {
// we cache up to 32 DMatrix
std::size_t static constexpr DefaultSize() { return 32; }

public:
PredictionContainer() = default;
/* \brief Add a new DMatrix to the cache, at the same time this function will clear out
* all expired caches by checking the `std::weak_ptr`. Caching an existing
* DMatrix won't renew it.
*
* Passing in a `shared_ptr` is critical here. First to create a `weak_ptr` inside the
* entry this shared pointer is necessary. More importantly, the life time of this
* cache is tied to the shared pointer.
*
* Another way to make a safe cache is create a proxy to this entry, with anther shared
* pointer defined inside, and pass this proxy around instead of the real entry. But
* seems to be too messy. In XGBoost, functions like `UpdateOneIter` will have
* (memory) safe access to the DMatrix as long as it's passed in as a `shared_ptr`.
*
* \param m shared pointer to the DMatrix that needs to be cached.
* \param device Which device should the cache be allocated on. Pass
* Context::kCpuId for CPU or positive integer for GPU id.
*
* \return the cache entry for passed in DMatrix, either an existing cache or newly
* created.
*/
PredictionCacheEntry& Cache(std::shared_ptr<DMatrix> m, int32_t device);
/* \brief Get a prediction cache entry. This entry must be already allocated by `Cache`
* method. Otherwise a dmlc::Error is thrown.
*
* \param m pointer to the DMatrix.
* \return The prediction cache for passed in DMatrix.
*/
PredictionCacheEntry& Entry(DMatrix* m);
/* \brief Get a const reference to the underlying hash map. Clear expired caches before
* returning.
*/
decltype(container_) const& Container();
PredictionContainer() : DMatrixCache<PredictionCacheEntry>{DefaultSize()} {}
PredictionCacheEntry& Cache(std::shared_ptr<DMatrix> m, int32_t device) {
this->CacheItem(m);
auto p_cache = this->container_.find(m.get());
if (device != Context::kCpuId) {
p_cache->second.Value().predictions.SetDevice(device);
}
return p_cache->second.Value();
}
};

/**
Expand All @@ -114,7 +85,7 @@ class Predictor {
*
* \param cfg The configuration.
*/
virtual void Configure(const std::vector<std::pair<std::string, std::string>>&);
virtual void Configure(Args const&);

/**
* \brief Initialize output prediction
Expand Down
2 changes: 1 addition & 1 deletion src/c_api/c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@
#include <vector>

#include "../collective/communicator-inl.h"
#include "../common/api_entry.h" // XGBAPIThreadLocalEntry
#include "../common/charconv.h"
#include "../common/io.h"
#include "../data/adapter.h"
#include "../data/simple_dmatrix.h"
#include "c_api_error.h"
#include "c_api_utils.h"
#include "xgboost/base.h"
#include "xgboost/data.h"
Expand Down
1 change: 1 addition & 0 deletions src/c_api/c_api.cu
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
/**
* Copyright 2019-2023 by XGBoost Contributors
*/
#include "../common/api_entry.h" // XGBAPIThreadLocalEntry
#include "../common/threading_utils.h"
#include "../data/device_adapter.cuh"
#include "../data/proxy_dmatrix.h"
Expand Down
35 changes: 35 additions & 0 deletions src/common/api_entry.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
/**
* Copyright 2016-2023 by XGBoost contributors
*/
#ifndef XGBOOST_COMMON_API_ENTRY_H_
#define XGBOOST_COMMON_API_ENTRY_H_
#include <string> // std::string
#include <vector> // std::vector

#include "xgboost/base.h" // GradientPair,bst_ulong
#include "xgboost/predictor.h" // PredictionCacheEntry

namespace xgboost {
/**
* \brief entry to to easily hold returning information
*/
struct XGBAPIThreadLocalEntry {
/*! \brief result holder for returning string */
std::string ret_str;
/*! \brief result holder for returning raw buffer */
std::vector<char> ret_char_vec;
/*! \brief result holder for returning strings */
std::vector<std::string> ret_vec_str;
/*! \brief result holder for returning string pointers */
std::vector<const char *> ret_vec_charp;
/*! \brief returning float vector. */
std::vector<float> ret_vec_float;
/*! \brief temp variable of gradient pairs. */
std::vector<GradientPair> tmp_gpair;
/*! \brief Temp variable for returning prediction result. */
PredictionCacheEntry prediction_entry;
/*! \brief Temp variable for returning prediction shape. */
std::vector<bst_ulong> prediction_shape;
};
} // namespace xgboost
#endif // XGBOOST_COMMON_API_ENTRY_H_
Loading