Skip to content

Commit

Permalink
Kick out py_shared_ptr and instead keep the python instance alive man…
Browse files Browse the repository at this point in the history
…ually by holding a reference to it

See pybind/pybind11#1389 for why py_shared_ptr was needed in the first place, and the comment from May 27 why we may not want to use it (reference cycle)
  • Loading branch information
florianwechsung committed Sep 23, 2021
1 parent 4bf5b85 commit 23bb44d
Show file tree
Hide file tree
Showing 7 changed files with 15 additions and 50 deletions.
1 change: 1 addition & 0 deletions src/simsopt/field/magneticfieldclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,6 +489,7 @@ def __init__(self, field, degree, rrange, phirange, zrange, extrapolate=True, nf
logger.warning(fr"Sure about phirange[1]={phirange[1]}? When exploiting rotational symmetry, the interpolant is never evaluated for phi>2\pi/nfp.")

sopp.InterpolatedField.__init__(self, field, degree, rrange, phirange, zrange, extrapolate, nfp, stellsym)
self.__field == field

def to_vtk(self, filename, h=0.1):
"""Export the field evaluated on a regular grid for visualisation with e.g. Paraview."""
Expand Down
30 changes: 0 additions & 30 deletions src/simsoptpp/py_shared_ptr.h

This file was deleted.

2 changes: 0 additions & 2 deletions src/simsoptpp/python.cpp
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
#include "pybind11/functional.h"
#include "py_shared_ptr.h"
PYBIND11_DECLARE_HOLDER_TYPE(T, py_shared_ptr<T>);
#define FORCE_IMPORT_ARRAY
#include "xtensor-python/pyarray.hpp" // Numpy bindings
typedef xt::pyarray<double> PyArray;
Expand Down
9 changes: 4 additions & 5 deletions src/simsoptpp/python_curves.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,9 @@
#include "pybind11/stl.h"
#include "xtensor-python/pyarray.hpp" // Numpy bindings
typedef xt::pyarray<double> PyArray;
#include "py_shared_ptr.h"
PYBIND11_DECLARE_HOLDER_TYPE(T, py_shared_ptr<T>);
using std::shared_ptr;

namespace py = pybind11;

#include "curve.h"
#include "pycurve.h"
Expand Down Expand Up @@ -90,16 +89,16 @@ template <typename T, typename S> void register_common_curve_methods(S &c) {
}

void init_curves(py::module_ &m) {
auto pycurve = py::class_<PyCurve, py_shared_ptr<PyCurve>, PyCurveTrampoline<PyCurve>>(m, "Curve")
auto pycurve = py::class_<PyCurve, shared_ptr<PyCurve>, PyCurveTrampoline<PyCurve>>(m, "Curve")
.def(py::init<vector<double>>());
register_common_curve_methods<PyCurve>(pycurve);

auto pycurvexyzfourier = py::class_<PyCurveXYZFourier, py_shared_ptr<PyCurveXYZFourier>, PyCurveXYZFourierTrampoline<PyCurveXYZFourier>, PyCurve>(m, "CurveXYZFourier")
auto pycurvexyzfourier = py::class_<PyCurveXYZFourier, shared_ptr<PyCurveXYZFourier>, PyCurveXYZFourierTrampoline<PyCurveXYZFourier>, PyCurve>(m, "CurveXYZFourier")
.def(py::init<vector<double>, int>())
.def_readonly("dofs", &PyCurveXYZFourier::dofs);
register_common_curve_methods<PyCurveXYZFourier>(pycurvexyzfourier);

auto pycurverzfourier = py::class_<PyCurveRZFourier, py_shared_ptr<PyCurveRZFourier>, PyCurveRZFourierTrampoline<PyCurveRZFourier>, PyCurve>(m, "CurveRZFourier")
auto pycurverzfourier = py::class_<PyCurveRZFourier, shared_ptr<PyCurveRZFourier>, PyCurveRZFourierTrampoline<PyCurveRZFourier>, PyCurve>(m, "CurveRZFourier")
//.def(py::init<int, int>())
.def(py::init<vector<double>, int, int, bool>())
.def_readwrite("rc", &PyCurveRZFourier::rc)
Expand Down
17 changes: 8 additions & 9 deletions src/simsoptpp/python_magneticfield.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,10 @@
#include "xtensor-python/pytensor.hpp" // Numpy bindings
typedef xt::pyarray<double> PyArray;
typedef xt::pytensor<double, 2, xt::layout_type::row_major> PyTensor;
#include "py_shared_ptr.h"
PYBIND11_DECLARE_HOLDER_TYPE(T, py_shared_ptr<T>);
using std::shared_ptr;
using std::vector;

namespace py = pybind11;
#include "magneticfield.h"
#include "magneticfield_biotsavart.h"
#include "magneticfield_interpolated.h"
Expand Down Expand Up @@ -54,17 +53,17 @@ template <typename T, typename S> void register_common_field_methods(S &c) {

void init_magneticfields(py::module_ &m){

py::class_<InterpolationRule, py_shared_ptr<InterpolationRule>>(m, "InterpolationRule", "Abstract class for interpolation rules on an interval.")
py::class_<InterpolationRule, shared_ptr<InterpolationRule>>(m, "InterpolationRule", "Abstract class for interpolation rules on an interval.")
.def_readonly("degree", &InterpolationRule::degree, "The degree of the polynomial. The number of interpolation points in `degree+1`.");

py::class_<UniformInterpolationRule, py_shared_ptr<UniformInterpolationRule>, InterpolationRule>(m, "UniformInterpolationRule", "Polynomial interpolation using equispaced points.")
py::class_<UniformInterpolationRule, shared_ptr<UniformInterpolationRule>, InterpolationRule>(m, "UniformInterpolationRule", "Polynomial interpolation using equispaced points.")
.def(py::init<int>())
.def_readonly("degree", &UniformInterpolationRule::degree, "The degree of the polynomial. The number of interpolation points in `degree+1`.");
py::class_<ChebyshevInterpolationRule, py_shared_ptr<ChebyshevInterpolationRule>, InterpolationRule>(m, "ChebyshevInterpolationRule", "Polynomial interpolation using chebychev points.")
py::class_<ChebyshevInterpolationRule, shared_ptr<ChebyshevInterpolationRule>, InterpolationRule>(m, "ChebyshevInterpolationRule", "Polynomial interpolation using chebychev points.")
.def(py::init<int>())
.def_readonly("degree", &ChebyshevInterpolationRule::degree, "The degree of the polynomial. The number of interpolation points in `degree+1`.");

py::class_<RegularGridInterpolant3D<PyTensor>, py_shared_ptr<RegularGridInterpolant3D<PyTensor>>>(m, "RegularGridInterpolant3D",
py::class_<RegularGridInterpolant3D<PyTensor>, shared_ptr<RegularGridInterpolant3D<PyTensor>>>(m, "RegularGridInterpolant3D",
R"pbdoc(
Interpolates a (vector valued) function on a uniform grid.
This interpolant is optimized for fast function evaluation (at the cost of memory usage). The main purpose of this class is to be used to interpolate magnetic fields and then use the interpolant for tasks such as fieldline or particle tracing for which the field needs to be evaluated many many times.
Expand Down Expand Up @@ -93,20 +92,20 @@ void init_magneticfields(py::module_ &m){
.def_readonly("curve", &Coil<PyArray>::curve, "Get the underlying curve.")
.def_readonly("current", &Coil<PyArray>::current, "Get the underlying current.");

auto mf = py::class_<PyMagneticField, PyMagneticFieldTrampoline<PyMagneticField>, py_shared_ptr<PyMagneticField>>(m, "MagneticField", "Abstract class representing magnetic fields.")
auto mf = py::class_<PyMagneticField, PyMagneticFieldTrampoline<PyMagneticField>, shared_ptr<PyMagneticField>>(m, "MagneticField", "Abstract class representing magnetic fields.")
.def(py::init<>());
register_common_field_methods<PyMagneticField>(mf);
//.def("B", py::overload_cast<>(&PyMagneticField::B));

auto bs = py::class_<PyBiotSavart, PyMagneticFieldTrampoline<PyBiotSavart>, py_shared_ptr<PyBiotSavart>, PyMagneticField>(m, "BiotSavart")
auto bs = py::class_<PyBiotSavart, PyMagneticFieldTrampoline<PyBiotSavart>, shared_ptr<PyBiotSavart>, PyMagneticField>(m, "BiotSavart")
.def(py::init<vector<shared_ptr<Coil<PyArray>>>>())
.def("compute", &PyBiotSavart::compute)
.def("fieldcache_get_or_create", &PyBiotSavart::fieldcache_get_or_create)
.def("fieldcache_get_status", &PyBiotSavart::fieldcache_get_status)
.def_readonly("coils", &PyBiotSavart::coils);
register_common_field_methods<PyBiotSavart>(bs);

auto ifield = py::class_<PyInterpolatedField, py_shared_ptr<PyInterpolatedField>, PyMagneticField>(m, "InterpolatedField")
auto ifield = py::class_<PyInterpolatedField, shared_ptr<PyInterpolatedField>, PyMagneticField>(m, "InterpolatedField")
.def(py::init<shared_ptr<PyMagneticField>, InterpolationRule, RangeTriplet, RangeTriplet, RangeTriplet, bool, int, bool>())
.def(py::init<shared_ptr<PyMagneticField>, int, RangeTriplet, RangeTriplet, RangeTriplet, bool, int, bool>())
.def("estimate_error_B", &PyInterpolatedField::estimate_error_B)
Expand Down
3 changes: 1 addition & 2 deletions src/simsoptpp/python_surfaces.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,10 @@
#include "pybind11/stl.h"
#include "xtensor-python/pyarray.hpp" // Numpy bindings
typedef xt::pyarray<double> PyArray;
#include "py_shared_ptr.h"
PYBIND11_DECLARE_HOLDER_TYPE(T, py_shared_ptr<T>);
using std::shared_ptr;
using std::vector;

namespace py = pybind11;
#include "pycurve.h"
#include "surface.h"
#include "pysurface.h"
Expand Down
3 changes: 1 addition & 2 deletions src/simsoptpp/python_tracing.cpp
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
#include "pybind11/functional.h"
namespace py = pybind11;
#include "xtensor-python/pyarray.hpp" // Numpy bindings
typedef xt::pyarray<double> PyArray;
#include "xtensor-python/pytensor.hpp" // Numpy bindings
typedef xt::pytensor<double, 2, xt::layout_type::row_major> PyTensor;
#include "py_shared_ptr.h"
PYBIND11_DECLARE_HOLDER_TYPE(T, py_shared_ptr<T>);
using std::shared_ptr;
using std::vector;
#include "tracing.h"
Expand Down

0 comments on commit 23bb44d

Please sign in to comment.