Skip to content

Commit

Permalink
Replace pybind11 with nanobind at the C++ module file
Browse files Browse the repository at this point in the history
  • Loading branch information
rauletorresc committed Dec 18, 2024
1 parent 6514903 commit 69941d2
Showing 1 changed file with 21 additions and 12 deletions.
33 changes: 21 additions & 12 deletions frontend/catalyst/third_party/oqc/src/oqc_python_module.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,12 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include <pybind11/eval.h>
#include <pybind11/pybind11.h>
#include <string>
#include <vector>

#include <nanobind/eval.h>
#include <nanobind/nanobind.h>
#include <nanobind/stl/string.h>

#include "Exception.hpp"

Expand Down Expand Up @@ -46,28 +49,34 @@ except Exception as e:
[[gnu::visibility("default")]] void counts(const char *_circuit, const char *_device, size_t shots,
size_t num_qubits, const char *_kwargs, void *_vector)
{
namespace py = pybind11;
using namespace py::literals;
namespace nb = nanobind;
using namespace nb::literals;

py::gil_scoped_acquire lock;
nb::gil_scoped_acquire lock;

auto locals = py::dict("circuit"_a = _circuit, "device"_a = _device, "kwargs"_a = _kwargs,
"shots"_a = shots, "msg"_a = "");
nb::dict locals;
locals["circuit"] = _circuit;
locals["device"] = _device;
locals["kwargs"] = _kwargs;
locals["shots"] = shots;
locals["msg"] = "";

py::exec(program, py::globals(), locals);
// Evaluate in scope of main module
nb::object scope = nb::module_::import_("__main__").attr("__dict__");
nb::exec(nb::str(program.c_str()), scope, locals);

auto &&msg = locals["msg"].cast<std::string>();
auto msg = nb::cast<std::string>(locals["msg"]);
RT_FAIL_IF(!msg.empty(), msg.c_str());

py::dict results = locals["counts"];
nb::dict results = locals["counts"];

std::vector<size_t> *counts_value = reinterpret_cast<std::vector<size_t> *>(_vector);
for (auto item : results) {
auto key = item.first;
auto value = item.second;
counts_value->push_back(value.cast<size_t>());
counts_value->push_back(nb::cast<size_t>(value));
}
return;
}

PYBIND11_MODULE(oqc_python_module, m) { m.doc() = "oqc"; }
NB_MODULE(oqc_python_module, m) { m.doc() = "oqc"; }

0 comments on commit 69941d2

Please sign in to comment.