-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #6 from KIT-MRT/add_python_bindings
Add python bindings
- Loading branch information
Showing
6 changed files
with
381 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,151 @@ | ||
#include <chrono> | ||
|
||
#include <pybind11/chrono.h> | ||
#include <pybind11/pybind11.h> | ||
#include <pybind11/stl.h> | ||
|
||
#include "cache.hpp" | ||
|
||
namespace util_caching::python_api { | ||
|
||
namespace py = pybind11; | ||
|
||
namespace number_based { | ||
namespace internal { | ||
|
||
/*! | ||
* \brief Bind the comparison policies to the Cache class | ||
* | ||
* This function binds the comparison policies to the Cache class. The policies | ||
* are passed as variadic template arguments. The function overloads the | ||
* `cached` function for each policy. | ||
*/ | ||
template <typename CacheT, typename NumberT, typename... ComparisonPolicyTs> | ||
void bindPolicies(py::class_<CacheT, std::shared_ptr<CacheT>> &cacheClass) { | ||
(cacheClass.def( | ||
"cached", | ||
[](CacheT &self, const NumberT &key, const ComparisonPolicyTs &policy) { | ||
return self.template cached<ComparisonPolicyTs>(key, policy); | ||
}, | ||
py::arg("key"), py::arg("policy")), | ||
...); | ||
} | ||
} // namespace internal | ||
|
||
/*! | ||
* \brief Bind the ApproximateNumber policy | ||
* | ||
* This function adds bindings for the ApproximateNumber policy to the given | ||
* python module under the given name. | ||
*/ | ||
template <typename NumberT> | ||
void bindApproximatePolicy(py::module &module, | ||
const std::string &name = "ApproximateNumber") { | ||
using ApproximateNumberT = policies::ApproximateNumber<NumberT>; | ||
py::class_<ApproximateNumberT, std::shared_ptr<ApproximateNumberT>>( | ||
module, name.c_str()) | ||
.def(py::init<NumberT>(), py::arg("threshold")) | ||
.def("__call__", &ApproximateNumberT::operator(), "Compare two numbers"); | ||
} | ||
|
||
/*! | ||
* \brief Bindings for a Cache that is based on number comparisons | ||
* | ||
* This function binds the Cache class for a specific number-based key type | ||
* (NumberT) and value type (ValueT). Optionally, add a list of comparison | ||
* policies to the list of template parameters. The `cached` function will be | ||
* overloaded for each one of them. Call this function once inside | ||
* PYBIND11_MODULE macro to create the bindings for the Cache class. | ||
*/ | ||
template <typename NumberT, typename ValueT, typename... ComparisonPolicies> | ||
void bindCache(py::module &module) { | ||
using CacheT = Cache<NumberT, ValueT>; | ||
py::class_<CacheT, std::shared_ptr<CacheT>> cache(module, "Cache"); | ||
cache | ||
.def(py::init<>()) | ||
// We cannot pass template parameters to python functions, therefore we | ||
// need to explicitly bind all instantiations to different python | ||
// functions. We need to use the lambdas here to handle the seconds | ||
// argument, defining the comparison policy. | ||
.def( | ||
"cached", | ||
[](CacheT &self, const NumberT &key) { | ||
return self.template cached<std::equal_to<NumberT>>(key); | ||
}, | ||
py::arg("key")) | ||
.def("cache", &CacheT::cache, py::arg("key"), py::arg("value")) | ||
.def("reset", &CacheT::reset); | ||
|
||
internal::bindPolicies<CacheT, NumberT, ComparisonPolicies...>(cache); | ||
} | ||
|
||
} // namespace number_based | ||
|
||
namespace time_based { | ||
namespace internal { | ||
|
||
/*! | ||
* \brief Bind the comparison policies to the Cache class | ||
* | ||
* This function binds the comparison policies to the Cache class. The policies | ||
* are passed as variadic template arguments. The function overloads the | ||
* `cached` function for each policy. | ||
*/ | ||
template <typename CacheT, typename TimeT, typename... ComparisonPolicyTs> | ||
void bindPolicies(py::class_<CacheT, std::shared_ptr<CacheT>> &cache) { | ||
(cache.def( | ||
"cached", | ||
[](CacheT &self, const TimeT &key, const ComparisonPolicyTs &policy) { | ||
return self.template cached<ComparisonPolicyTs>(key, policy); | ||
}, | ||
py::arg("key"), py::arg("policy")), | ||
...); | ||
} | ||
} // namespace internal | ||
|
||
/*! | ||
* \brief Bind the ApproximateTime policy | ||
* | ||
* This function adds bindings for the ApproximateTime policy to the given | ||
* python module under the given name. | ||
*/ | ||
template <typename TimeT, typename ThresholdTimeUnitT> | ||
void bindApproximatePolicy(py::module &module, | ||
const std::string &name = "ApproximateTime") { | ||
using ApproximateTimeT = policies::ApproximateTime<TimeT, ThresholdTimeUnitT>; | ||
py::class_<ApproximateTimeT, std::shared_ptr<ApproximateTimeT>>(module, | ||
name.c_str()) | ||
.def(py::init<double>(), py::arg("threshold")) | ||
.def("__call__", &ApproximateTimeT::operator(), | ||
"Compare two time points"); | ||
} | ||
|
||
/*! | ||
* \brief Bindings for a Cache that is based on time comparisons. | ||
* | ||
* This function binds the Cache class for a specific time-based key type | ||
* (TimeT) and value type (ValueT). Optionally, add a list of comparison | ||
* policies to the list of template parameters. The `cached` function will be | ||
* overloaded for each one of them. Call this function once inside | ||
* PYBIND11_MODULE macro to create the bindings for the Cache class. | ||
*/ | ||
template <typename TimeT, typename ValueT, typename... ComparisonPolicyTs> | ||
void bindCache(py::module &module) { | ||
using CacheT = Cache<TimeT, ValueT>; | ||
|
||
py::class_<CacheT, std::shared_ptr<CacheT>> cache(module, "Cache"); | ||
cache.def(py::init<>()) | ||
.def( | ||
"cached", | ||
[](CacheT &self, const TimeT &key) { | ||
return self.template cached<std::equal_to<TimeT>>(key); | ||
}, | ||
py::arg("key")) | ||
.def("cache", &CacheT::cache, py::arg("key"), py::arg("value")) | ||
.def("reset", &CacheT::reset); | ||
|
||
internal::bindPolicies<CacheT, TimeT, ComparisonPolicyTs...>(cache); | ||
} | ||
|
||
} // namespace time_based | ||
} // namespace util_caching::python_api |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
#include <pybind11/pybind11.h> | ||
|
||
#include "cache.hpp" | ||
#include "python_bindings.hpp" | ||
#include "types.hpp" | ||
|
||
namespace py = pybind11; | ||
|
||
using namespace util_caching; | ||
|
||
/*! | ||
* \brief A policy that always returns true | ||
* | ||
* Custom policies have to be defined in C++ and then bound to Python. | ||
* To overload the `cache` function, the policy has to be passed as a template parameter to the `bindCache` function. | ||
*/ | ||
struct SomePolicyWithoutParams { | ||
SomePolicyWithoutParams() = default; | ||
bool operator()(const Time& /*lhs*/, const Time& /*rhs*/) const { | ||
return true; | ||
} | ||
}; | ||
|
||
/*! | ||
* \brief The python module definition that allows running python unit tests equivalent to the native C++ ones. | ||
*/ | ||
PYBIND11_MODULE(util_caching_py, mainModule) { | ||
// Just some aliases to make the code more readable | ||
using ApproximateNumberT = policies::ApproximateNumber<double>; | ||
using ApproximateTimeT = policies::ApproximateTime<Time, std::chrono::milliseconds>; | ||
using ApproximateTimeSecondsT = policies::ApproximateTime<Time, std::chrono::seconds>; | ||
|
||
// Since we want to use this policy in python, we need to be able to instatiate it there | ||
py::class_<SomePolicyWithoutParams, std::shared_ptr<SomePolicyWithoutParams>>(mainModule, "SomePolicyWithoutParams") | ||
.def(py::init<>()) | ||
.def("__call__", &SomePolicyWithoutParams::operator()); | ||
|
||
// Adding a submodule is optional but a good way to structure the bindings | ||
py::module numberBased = mainModule.def_submodule("number_based"); | ||
// If we want to use a policy, we need to bind it. For the builtin policies, we can use this convenience function. | ||
python_api::number_based::bindApproximatePolicy<double>(numberBased); | ||
// The core binding, the cache class itself. | ||
python_api::number_based::bindCache<double, // KeyType | ||
double, // Value type | ||
ApproximateNumberT // Optionally, a list of comparison policies, each one will | ||
// overload the `cached` function | ||
>(numberBased); | ||
|
||
// Same as above, but for the time-based cache | ||
py::module timeBased = mainModule.def_submodule("time_based"); | ||
// We can bind the builtin comparison policy for different time units but then we have to name them differently | ||
python_api::time_based::bindApproximatePolicy<Time, std::chrono::milliseconds>(timeBased, "ApproximateTime"); | ||
python_api::time_based::bindApproximatePolicy<Time, std::chrono::seconds>(timeBased, "ApproximateTimeSeconds"); | ||
// The core binding, the cache class itself. | ||
python_api::time_based::bindCache<Time, // Key type | ||
double, // Value type | ||
ApproximateTimeT, // A list of all comparison policies we intend to use | ||
ApproximateTimeSecondsT, | ||
SomePolicyWithoutParams>(timeBased); | ||
} |
Oops, something went wrong.