Skip to content

Commit

Permalink
Merge pull request #6 from KIT-MRT/add_python_bindings
Browse files Browse the repository at this point in the history
Add python bindings
  • Loading branch information
ll-nick authored Dec 6, 2024
2 parents ec75a0f + 3d21e07 commit db9e97c
Show file tree
Hide file tree
Showing 6 changed files with 381 additions and 7 deletions.
5 changes: 4 additions & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@ RUN apt-get update && \
apt-get install -y \
build-essential \
cmake \
libgtest-dev && \
libgtest-dev \
pybind11-dev \
python3-dev \
python3-pybind11 && \
apt-get clean


Expand Down
25 changes: 24 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,28 @@ or by specifying one comparison policy and threshold (100ms for example), and re

More usage please check the unittest.

## Python bindings

The library can be used in Python via pybind11 bindings.
Since `util_caching` is a templated C++ library,
you need to explicitly instantiate the template for the types you want to use in Python.
For this, we provide convenience functions to bind the library for the desired types.
Simply call them in a pybind11 module definition, e.g.:

```cpp
PYBIND11_MODULE(util_caching, m) {
python_api::number_based::bindCache<double, double>(m);
}
```
and use them in Python:
```python
from util_caching import Cache
cache = Cache()
cache.cache(1.0, 2.0)
```
We re-implemented all of the C++ unit tests in Python, so take a closer look at those for more advanced usage examples.


## Installation

Expand Down Expand Up @@ -111,7 +133,8 @@ find_package(util_caching REQUIRED)
### Building from source using CMake

First make sure all dependencies are installed:
- [Googletest](https://github.com/google/googletest) (only if you want to build unit tests)
- [Googletest](https://github.com/google/googletest) (optional, if you want to build unit tests)
- [pybind11](https://pybind11.readthedocs.io/en/stable/) (optional, if you want to build Python bindings and unit tests)

See also the [`Dockerfile`](./Dockerfile) for how to install these packages under Debian or Ubuntu.

Expand Down
151 changes: 151 additions & 0 deletions include/util_caching/python_bindings.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
#include <chrono>

#include <pybind11/chrono.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include "cache.hpp"

namespace util_caching::python_api {

namespace py = pybind11;

namespace number_based {
namespace internal {

/*!
* \brief Bind the comparison policies to the Cache class
*
* This function binds the comparison policies to the Cache class. The policies
* are passed as variadic template arguments. The function overloads the
* `cached` function for each policy.
*/
template <typename CacheT, typename NumberT, typename... ComparisonPolicyTs>
void bindPolicies(py::class_<CacheT, std::shared_ptr<CacheT>> &cacheClass) {
(cacheClass.def(
"cached",
[](CacheT &self, const NumberT &key, const ComparisonPolicyTs &policy) {
return self.template cached<ComparisonPolicyTs>(key, policy);
},
py::arg("key"), py::arg("policy")),
...);
}
} // namespace internal

/*!
* \brief Bind the ApproximateNumber policy
*
* This function adds bindings for the ApproximateNumber policy to the given
* python module under the given name.
*/
template <typename NumberT>
void bindApproximatePolicy(py::module &module,
const std::string &name = "ApproximateNumber") {
using ApproximateNumberT = policies::ApproximateNumber<NumberT>;
py::class_<ApproximateNumberT, std::shared_ptr<ApproximateNumberT>>(
module, name.c_str())
.def(py::init<NumberT>(), py::arg("threshold"))
.def("__call__", &ApproximateNumberT::operator(), "Compare two numbers");
}

/*!
* \brief Bindings for a Cache that is based on number comparisons
*
* This function binds the Cache class for a specific number-based key type
* (NumberT) and value type (ValueT). Optionally, add a list of comparison
* policies to the list of template parameters. The `cached` function will be
* overloaded for each one of them. Call this function once inside
* PYBIND11_MODULE macro to create the bindings for the Cache class.
*/
template <typename NumberT, typename ValueT, typename... ComparisonPolicies>
void bindCache(py::module &module) {
using CacheT = Cache<NumberT, ValueT>;
py::class_<CacheT, std::shared_ptr<CacheT>> cache(module, "Cache");
cache
.def(py::init<>())
// We cannot pass template parameters to python functions, therefore we
// need to explicitly bind all instantiations to different python
// functions. We need to use the lambdas here to handle the seconds
// argument, defining the comparison policy.
.def(
"cached",
[](CacheT &self, const NumberT &key) {
return self.template cached<std::equal_to<NumberT>>(key);
},
py::arg("key"))
.def("cache", &CacheT::cache, py::arg("key"), py::arg("value"))
.def("reset", &CacheT::reset);

internal::bindPolicies<CacheT, NumberT, ComparisonPolicies...>(cache);
}

} // namespace number_based

namespace time_based {
namespace internal {

/*!
* \brief Bind the comparison policies to the Cache class
*
* This function binds the comparison policies to the Cache class. The policies
* are passed as variadic template arguments. The function overloads the
* `cached` function for each policy.
*/
template <typename CacheT, typename TimeT, typename... ComparisonPolicyTs>
void bindPolicies(py::class_<CacheT, std::shared_ptr<CacheT>> &cache) {
(cache.def(
"cached",
[](CacheT &self, const TimeT &key, const ComparisonPolicyTs &policy) {
return self.template cached<ComparisonPolicyTs>(key, policy);
},
py::arg("key"), py::arg("policy")),
...);
}
} // namespace internal

/*!
* \brief Bind the ApproximateTime policy
*
* This function adds bindings for the ApproximateTime policy to the given
* python module under the given name.
*/
template <typename TimeT, typename ThresholdTimeUnitT>
void bindApproximatePolicy(py::module &module,
const std::string &name = "ApproximateTime") {
using ApproximateTimeT = policies::ApproximateTime<TimeT, ThresholdTimeUnitT>;
py::class_<ApproximateTimeT, std::shared_ptr<ApproximateTimeT>>(module,
name.c_str())
.def(py::init<double>(), py::arg("threshold"))
.def("__call__", &ApproximateTimeT::operator(),
"Compare two time points");
}

/*!
* \brief Bindings for a Cache that is based on time comparisons.
*
* This function binds the Cache class for a specific time-based key type
* (TimeT) and value type (ValueT). Optionally, add a list of comparison
* policies to the list of template parameters. The `cached` function will be
* overloaded for each one of them. Call this function once inside
* PYBIND11_MODULE macro to create the bindings for the Cache class.
*/
template <typename TimeT, typename ValueT, typename... ComparisonPolicyTs>
void bindCache(py::module &module) {
using CacheT = Cache<TimeT, ValueT>;

py::class_<CacheT, std::shared_ptr<CacheT>> cache(module, "Cache");
cache.def(py::init<>())
.def(
"cached",
[](CacheT &self, const TimeT &key) {
return self.template cached<std::equal_to<TimeT>>(key);
},
py::arg("key"))
.def("cache", &CacheT::cache, py::arg("key"), py::arg("value"))
.def("reset", &CacheT::reset);

internal::bindPolicies<CacheT, TimeT, ComparisonPolicyTs...>(cache);
}

} // namespace time_based
} // namespace util_caching::python_api
52 changes: 47 additions & 5 deletions test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -43,20 +43,26 @@ endif()
###################

find_package(GTest)
find_package(pybind11 CONFIG)

if(NOT GTEST_FOUND AND NOT pybind11_FOUND)
message(WARNING "Neither GTest nor pybind11 found. Cannot compile tests!")
endif()

# Find installed lib and its dependencies, if this is build as top-level project
if(CMAKE_PROJECT_NAME STREQUAL PROJECT_NAME)
find_package(util_caching REQUIRED)
endif()


###########
## Build ##
###########
####################
## C++ Unit Tests ##
####################

if(GTEST_FOUND)
file(GLOB_RECURSE _tests CONFIGURE_DEPENDS "*.cpp" "*.cc")
list(FILTER _tests EXCLUDE REGEX "${CMAKE_CURRENT_BINARY_DIR}")
list(REMOVE_ITEM _tests "${CMAKE_CURRENT_SOURCE_DIR}/python_bindings.cpp")

foreach(_test ${_tests})
get_filename_component(_test_name ${_test} NAME_WE)
Expand All @@ -80,6 +86,42 @@ if(GTEST_FOUND)
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
)
endforeach()
else()
message(WARNING "GTest not found. Cannot compile tests!")
endif()

#######################
## Python Unit Tests ##
#######################

if(pybind11_FOUND)
# Find Python3 to run tests via ctest
find_package(Python3 REQUIRED)

# Python bindings modules
pybind11_add_module(util_caching_py
python_bindings.cpp
)
target_link_libraries(util_caching_py PUBLIC
util_caching
)

file(GLOB_RECURSE _py_tests CONFIGURE_DEPENDS "*.py")

# Copy Python test files to build directory
foreach(_py_test ${_py_tests})
get_filename_component(_py_test_name ${_py_test} NAME)
string(REGEX REPLACE "-test" "" PY_TEST_NAME ${_py_test_name})
set(PY_TEST_NAME ${PROJECT_NAME}-pytest-${PY_TEST_NAME})

message(STATUS
"Adding python unittest \"${PY_TEST_NAME}\" with working dir ${PROJECT_SOURCE_DIR}/${TEST_FOLDER} \n _test: ${_py_test}"
)

configure_file(${_py_test} ${PY_TEST_NAME} COPYONLY)

add_test(NAME ${PY_TEST_NAME}
COMMAND ${Python3_EXECUTABLE} ${PY_TEST_NAME}
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}
)
endforeach()
endif()

60 changes: 60 additions & 0 deletions test/python_bindings.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
#include <pybind11/pybind11.h>

#include "cache.hpp"
#include "python_bindings.hpp"
#include "types.hpp"

namespace py = pybind11;

using namespace util_caching;

/*!
* \brief A policy that always returns true
*
* Custom policies have to be defined in C++ and then bound to Python.
* To overload the `cache` function, the policy has to be passed as a template parameter to the `bindCache` function.
*/
struct SomePolicyWithoutParams {
SomePolicyWithoutParams() = default;
bool operator()(const Time& /*lhs*/, const Time& /*rhs*/) const {
return true;
}
};

/*!
* \brief The python module definition that allows running python unit tests equivalent to the native C++ ones.
*/
PYBIND11_MODULE(util_caching_py, mainModule) {
// Just some aliases to make the code more readable
using ApproximateNumberT = policies::ApproximateNumber<double>;
using ApproximateTimeT = policies::ApproximateTime<Time, std::chrono::milliseconds>;
using ApproximateTimeSecondsT = policies::ApproximateTime<Time, std::chrono::seconds>;

// Since we want to use this policy in python, we need to be able to instatiate it there
py::class_<SomePolicyWithoutParams, std::shared_ptr<SomePolicyWithoutParams>>(mainModule, "SomePolicyWithoutParams")
.def(py::init<>())
.def("__call__", &SomePolicyWithoutParams::operator());

// Adding a submodule is optional but a good way to structure the bindings
py::module numberBased = mainModule.def_submodule("number_based");
// If we want to use a policy, we need to bind it. For the builtin policies, we can use this convenience function.
python_api::number_based::bindApproximatePolicy<double>(numberBased);
// The core binding, the cache class itself.
python_api::number_based::bindCache<double, // KeyType
double, // Value type
ApproximateNumberT // Optionally, a list of comparison policies, each one will
// overload the `cached` function
>(numberBased);

// Same as above, but for the time-based cache
py::module timeBased = mainModule.def_submodule("time_based");
// We can bind the builtin comparison policy for different time units but then we have to name them differently
python_api::time_based::bindApproximatePolicy<Time, std::chrono::milliseconds>(timeBased, "ApproximateTime");
python_api::time_based::bindApproximatePolicy<Time, std::chrono::seconds>(timeBased, "ApproximateTimeSeconds");
// The core binding, the cache class itself.
python_api::time_based::bindCache<Time, // Key type
double, // Value type
ApproximateTimeT, // A list of all comparison policies we intend to use
ApproximateTimeSecondsT,
SomePolicyWithoutParams>(timeBased);
}
Loading

0 comments on commit db9e97c

Please sign in to comment.