diff --git a/src/simsopt/field/magneticfieldclasses.py b/src/simsopt/field/magneticfieldclasses.py index 286d436fa..474217b2b 100644 --- a/src/simsopt/field/magneticfieldclasses.py +++ b/src/simsopt/field/magneticfieldclasses.py @@ -489,6 +489,7 @@ def __init__(self, field, degree, rrange, phirange, zrange, extrapolate=True, nf logger.warning(fr"Sure about phirange[1]={phirange[1]}? When exploiting rotational symmetry, the interpolant is never evaluated for phi>2\pi/nfp.") sopp.InterpolatedField.__init__(self, field, degree, rrange, phirange, zrange, extrapolate, nfp, stellsym) + self.__field == field def to_vtk(self, filename, h=0.1): """Export the field evaluated on a regular grid for visualisation with e.g. Paraview.""" diff --git a/src/simsoptpp/py_shared_ptr.h b/src/simsoptpp/py_shared_ptr.h deleted file mode 100644 index be7ecce99..000000000 --- a/src/simsoptpp/py_shared_ptr.h +++ /dev/null @@ -1,30 +0,0 @@ -#pragma once - -// see https://github.com/pybind/pybind11/issues/1389 - -#include -#include -namespace py = pybind11; - -template class py_shared_ptr { - private: - std::shared_ptr _impl; - - public: - using element_type = T; - - py_shared_ptr(T *ptr) { - py::object pyobj = py::cast(ptr); - PyObject* pyptr = pyobj.ptr(); - Py_INCREF(pyptr); - std::shared_ptr vec_py_ptr( - pyptr, [](PyObject *ob) { Py_DECREF(ob); }); - _impl = std::shared_ptr(vec_py_ptr, ptr); - } - - py_shared_ptr(std::shared_ptr r): _impl(r){} - - operator std::shared_ptr() { return _impl; } - - T* get() const {return _impl.get();} -}; diff --git a/src/simsoptpp/python.cpp b/src/simsoptpp/python.cpp index 48e1977aa..f556ba541 100644 --- a/src/simsoptpp/python.cpp +++ b/src/simsoptpp/python.cpp @@ -1,8 +1,6 @@ #include "pybind11/pybind11.h" #include "pybind11/stl.h" #include "pybind11/functional.h" -#include "py_shared_ptr.h" -PYBIND11_DECLARE_HOLDER_TYPE(T, py_shared_ptr); #define FORCE_IMPORT_ARRAY #include "xtensor-python/pyarray.hpp" // Numpy bindings typedef xt::pyarray PyArray; diff --git a/src/simsoptpp/python_curves.cpp b/src/simsoptpp/python_curves.cpp index 70e73e8b8..1f83b6127 100644 --- a/src/simsoptpp/python_curves.cpp +++ b/src/simsoptpp/python_curves.cpp @@ -2,10 +2,9 @@ #include "pybind11/stl.h" #include "xtensor-python/pyarray.hpp" // Numpy bindings typedef xt::pyarray PyArray; -#include "py_shared_ptr.h" -PYBIND11_DECLARE_HOLDER_TYPE(T, py_shared_ptr); using std::shared_ptr; +namespace py = pybind11; #include "curve.h" #include "pycurve.h" @@ -90,16 +89,16 @@ template void register_common_curve_methods(S &c) { } void init_curves(py::module_ &m) { - auto pycurve = py::class_, PyCurveTrampoline>(m, "Curve") + auto pycurve = py::class_, PyCurveTrampoline>(m, "Curve") .def(py::init>()); register_common_curve_methods(pycurve); - auto pycurvexyzfourier = py::class_, PyCurveXYZFourierTrampoline, PyCurve>(m, "CurveXYZFourier") + auto pycurvexyzfourier = py::class_, PyCurveXYZFourierTrampoline, PyCurve>(m, "CurveXYZFourier") .def(py::init, int>()) .def_readonly("dofs", &PyCurveXYZFourier::dofs); register_common_curve_methods(pycurvexyzfourier); - auto pycurverzfourier = py::class_, PyCurveRZFourierTrampoline, PyCurve>(m, "CurveRZFourier") + auto pycurverzfourier = py::class_, PyCurveRZFourierTrampoline, PyCurve>(m, "CurveRZFourier") //.def(py::init()) .def(py::init, int, int, bool>()) .def_readwrite("rc", &PyCurveRZFourier::rc) diff --git a/src/simsoptpp/python_magneticfield.cpp b/src/simsoptpp/python_magneticfield.cpp index 1cff744ac..d60267144 100644 --- a/src/simsoptpp/python_magneticfield.cpp +++ b/src/simsoptpp/python_magneticfield.cpp @@ -5,11 +5,10 @@ #include "xtensor-python/pytensor.hpp" // Numpy bindings typedef xt::pyarray PyArray; typedef xt::pytensor PyTensor; -#include "py_shared_ptr.h" -PYBIND11_DECLARE_HOLDER_TYPE(T, py_shared_ptr); using std::shared_ptr; using std::vector; +namespace py = pybind11; #include "magneticfield.h" #include "magneticfield_biotsavart.h" #include "magneticfield_interpolated.h" @@ -54,17 +53,17 @@ template void register_common_field_methods(S &c) { void init_magneticfields(py::module_ &m){ - py::class_>(m, "InterpolationRule", "Abstract class for interpolation rules on an interval.") + py::class_>(m, "InterpolationRule", "Abstract class for interpolation rules on an interval.") .def_readonly("degree", &InterpolationRule::degree, "The degree of the polynomial. The number of interpolation points in `degree+1`."); - py::class_, InterpolationRule>(m, "UniformInterpolationRule", "Polynomial interpolation using equispaced points.") + py::class_, InterpolationRule>(m, "UniformInterpolationRule", "Polynomial interpolation using equispaced points.") .def(py::init()) .def_readonly("degree", &UniformInterpolationRule::degree, "The degree of the polynomial. The number of interpolation points in `degree+1`."); - py::class_, InterpolationRule>(m, "ChebyshevInterpolationRule", "Polynomial interpolation using chebychev points.") + py::class_, InterpolationRule>(m, "ChebyshevInterpolationRule", "Polynomial interpolation using chebychev points.") .def(py::init()) .def_readonly("degree", &ChebyshevInterpolationRule::degree, "The degree of the polynomial. The number of interpolation points in `degree+1`."); - py::class_, py_shared_ptr>>(m, "RegularGridInterpolant3D", + py::class_, shared_ptr>>(m, "RegularGridInterpolant3D", R"pbdoc( Interpolates a (vector valued) function on a uniform grid. This interpolant is optimized for fast function evaluation (at the cost of memory usage). The main purpose of this class is to be used to interpolate magnetic fields and then use the interpolant for tasks such as fieldline or particle tracing for which the field needs to be evaluated many many times. @@ -93,12 +92,12 @@ void init_magneticfields(py::module_ &m){ .def_readonly("curve", &Coil::curve, "Get the underlying curve.") .def_readonly("current", &Coil::current, "Get the underlying current."); - auto mf = py::class_, py_shared_ptr>(m, "MagneticField", "Abstract class representing magnetic fields.") + auto mf = py::class_, shared_ptr>(m, "MagneticField", "Abstract class representing magnetic fields.") .def(py::init<>()); register_common_field_methods(mf); //.def("B", py::overload_cast<>(&PyMagneticField::B)); - auto bs = py::class_, py_shared_ptr, PyMagneticField>(m, "BiotSavart") + auto bs = py::class_, shared_ptr, PyMagneticField>(m, "BiotSavart") .def(py::init>>>()) .def("compute", &PyBiotSavart::compute) .def("fieldcache_get_or_create", &PyBiotSavart::fieldcache_get_or_create) @@ -106,7 +105,7 @@ void init_magneticfields(py::module_ &m){ .def_readonly("coils", &PyBiotSavart::coils); register_common_field_methods(bs); - auto ifield = py::class_, PyMagneticField>(m, "InterpolatedField") + auto ifield = py::class_, PyMagneticField>(m, "InterpolatedField") .def(py::init, InterpolationRule, RangeTriplet, RangeTriplet, RangeTriplet, bool, int, bool>()) .def(py::init, int, RangeTriplet, RangeTriplet, RangeTriplet, bool, int, bool>()) .def("estimate_error_B", &PyInterpolatedField::estimate_error_B) diff --git a/src/simsoptpp/python_surfaces.cpp b/src/simsoptpp/python_surfaces.cpp index 606c9f628..0bd800630 100644 --- a/src/simsoptpp/python_surfaces.cpp +++ b/src/simsoptpp/python_surfaces.cpp @@ -2,11 +2,10 @@ #include "pybind11/stl.h" #include "xtensor-python/pyarray.hpp" // Numpy bindings typedef xt::pyarray PyArray; -#include "py_shared_ptr.h" -PYBIND11_DECLARE_HOLDER_TYPE(T, py_shared_ptr); using std::shared_ptr; using std::vector; +namespace py = pybind11; #include "pycurve.h" #include "surface.h" #include "pysurface.h" diff --git a/src/simsoptpp/python_tracing.cpp b/src/simsoptpp/python_tracing.cpp index 30a1f6bc4..071ed09f9 100644 --- a/src/simsoptpp/python_tracing.cpp +++ b/src/simsoptpp/python_tracing.cpp @@ -1,12 +1,11 @@ #include "pybind11/pybind11.h" #include "pybind11/stl.h" #include "pybind11/functional.h" +namespace py = pybind11; #include "xtensor-python/pyarray.hpp" // Numpy bindings typedef xt::pyarray PyArray; #include "xtensor-python/pytensor.hpp" // Numpy bindings typedef xt::pytensor PyTensor; -#include "py_shared_ptr.h" -PYBIND11_DECLARE_HOLDER_TYPE(T, py_shared_ptr); using std::shared_ptr; using std::vector; #include "tracing.h"