-
Notifications
You must be signed in to change notification settings - Fork 29
/
python.cc
164 lines (154 loc) · 5.6 KB
/
python.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
#include <climits>
#include <functional>
#include <memory>
#include <Python.h>
#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
#include <numpy/arrayobject.h>
#include "cpu_id.h"
#include "lap.h"
// Cached CPUID probe, initialized once at shared-library load time; queried by
// call_lap() to choose between the AVX2 and scalar lap<> instantiations.
static SIMDFlags simd_flags = SIMDFlags();

// Docstrings exposed to Python for the module and its single function.
static char module_docstring[] =
    "This module wraps LAPJV - Jonker-Volgenant linear sum assignment algorithm.";
static char lapjv_docstring[] =
    "Solves the linear sum assignment problem.";

// Forward declaration; defined below after the helper machinery.
static PyObject *py_lapjv(PyObject *self, PyObject *args, PyObject *kwargs);

// Method table: lapjv() is the module's only entry point. The cast is required
// because py_lapjv takes kwargs (METH_KEYWORDS signature) while the table
// field is declared as the plain PyCFunction type.
static PyMethodDef module_functions[] = {
    {"lapjv", reinterpret_cast<PyCFunction>(py_lapjv),
     METH_VARARGS | METH_KEYWORDS, lapjv_docstring},
    {NULL, NULL, 0, NULL}
};
extern "C" {
/// Python 3 module initialization: creates the "lapjv" module object and
/// initializes the NumPy C-API.
///
/// Returns the new module object, or NULL with an exception set on failure.
PyMODINIT_FUNC PyInit_lapjv(void) {
  static struct PyModuleDef moduledef = {
      PyModuleDef_HEAD_INIT,
      "lapjv",           /* m_name */
      module_docstring,  /* m_doc */
      -1,                /* m_size */
      module_functions,  /* m_methods */
      NULL,              /* m_reload */
      NULL,              /* m_traverse */
      NULL,              /* m_clear */
      NULL,              /* m_free */
  };
  // import_array() is a macro that, on failure, sets an exception and executes
  // `return NULL;` from this function. Run it BEFORE PyModule_Create so a
  // failed numpy import cannot leak the freshly created module object (the
  // original code created the module first and would leak it here).
  import_array();
  PyObject *m = PyModule_Create(&moduledef);
  if (m == NULL) {
    PyErr_SetString(PyExc_RuntimeError, "PyModule_Create() failed");
    return NULL;
  }
  return m;
}
}  // extern "C"
// Owning smart-pointer base for CPython objects: a unique_ptr whose deleter
// drops one reference instead of calling delete.
template <typename O>
using pyobj_parent = std::unique_ptr<O, std::function<void(O*)>>;

// RAII wrapper around a PyObject-compatible pointer. Takes ownership of one
// reference and releases it (Py_DECREF) on destruction or reset.
template <typename O>
class _pyobj : public pyobj_parent<O> {
 public:
  _pyobj() : pyobj_parent<O>(nullptr, &_pyobj::drop_ref) {}
  // Adopts `raw` (steals the reference); NULL is allowed and means "empty".
  explicit _pyobj(PyObject *raw)
      : pyobj_parent<O>(reinterpret_cast<O*>(raw), &_pyobj::drop_ref) {}
  // Convenience reset accepting an untyped PyObject*, mirroring the ctor.
  void reset(PyObject *raw) noexcept {
    pyobj_parent<O>::reset(reinterpret_cast<O*>(raw));
  }

 private:
  // Shared deleter: decref non-NULL pointers only (Py_DECREF requires non-NULL).
  static void drop_ref(O *obj) {
    if (obj) {
      Py_DECREF(obj);
    }
  }
};

using pyobj = _pyobj<PyObject>;
using pyarray = _pyobj<PyArrayObject>;
// Runs the LAPJV solver with the GIL released, dispatching to the lap<>
// instantiation that matches runtime AVX2 availability and verbosity.
//
// F is the element type of the cost matrix (float or double); cost_matrix,
// u and v are reinterpreted as (const) F*. dim is the square matrix size,
// row_ind/col_ind receive the assignment (dim ints each), u/v receive the
// dual variables (dim F values each). Returns the total assignment cost.
template <typename F>
static always_inline double call_lap(int dim, const void *restrict cost_matrix,
bool verbose, bool disable_avx,
int *restrict row_ind, int *restrict col_ind,
void *restrict u, void *restrict v) {
double lapcost;
// GIL is released from here to Py_END_ALLOW_THREADS: no CPython API calls
// are allowed inside this region (printf is plain C and is fine).
Py_BEGIN_ALLOW_THREADS
bool hasAVX2 = simd_flags.hasAVX2();
if (verbose) {
printf("AVX2: %s\n", hasAVX2? "enabled" : "disabled");
}
auto cost_matrix_typed = reinterpret_cast<const F*>(cost_matrix);
auto u_typed = reinterpret_cast<F*>(u);
auto v_typed = reinterpret_cast<F*>(v);
// AVX2 usage and verbosity are compile-time template parameters of lap<>,
// hence the explicit 2x2 runtime dispatch over the four instantiations.
if (hasAVX2 && !disable_avx) {
if (verbose) {
lapcost = lap<true, true>(dim, cost_matrix_typed, row_ind, col_ind, u_typed, v_typed);
} else {
lapcost = lap<true, false>(dim, cost_matrix_typed, row_ind, col_ind, u_typed, v_typed);
}
} else {
if (verbose) {
lapcost = lap<false, true>(dim, cost_matrix_typed, row_ind, col_ind, u_typed, v_typed);
} else {
lapcost = lap<false, false>(dim, cost_matrix_typed, row_ind, col_ind, u_typed, v_typed);
}
}
Py_END_ALLOW_THREADS
return lapcost;
}
/// lapjv(cost_matrix, verbose=False, disable_avx=False, force_doubles=False)
///
/// Solves the linear sum assignment problem for a square float32/float64 cost
/// matrix. Returns (row_ind, col_ind, (lapcost, u, v)) where row_ind/col_ind
/// are int arrays with the assignment and u/v are the dual variables, or NULL
/// with a Python exception set on error.
static PyObject *py_lapjv(PyObject *self, PyObject *args, PyObject *kwargs) {
  PyObject *cost_matrix_obj;
  int verbose = 0;
  int disable_avx = 0;
  int force_doubles = 0;
  static const char *kwlist[] = {
      "cost_matrix", "verbose", "disable_avx", "force_doubles", NULL};
  // All three flags use "p" (Python truth value -> int). The original format
  // was "O|pbb": "b" stores a single unsigned char into the int-sized
  // disable_avx / force_doubles variables, which only works by accident on
  // little-endian targets with the zero initialization above.
  if (!PyArg_ParseTupleAndKeywords(
          args, kwargs, "O|ppp", const_cast<char**>(kwlist),
          &cost_matrix_obj, &verbose, &disable_avx, &force_doubles)) {
    return NULL;
  }
  pyarray cost_matrix_array;
  bool float32 = true;
  // First try float32; unless force_doubles is set, allow casting wider
  // dtypes down to float32 via NPY_ARRAY_FORCECAST.
  cost_matrix_array.reset(PyArray_FROM_OTF(
      cost_matrix_obj, NPY_FLOAT32,
      NPY_ARRAY_IN_ARRAY | (force_doubles? 0 : NPY_ARRAY_FORCECAST)));
  if (!cost_matrix_array) {
    // Fall back to float64; clear the float32 conversion error first.
    PyErr_Clear();
    float32 = false;
    cost_matrix_array.reset(PyArray_FROM_OTF(
        cost_matrix_obj, NPY_FLOAT64, NPY_ARRAY_IN_ARRAY));
    if (!cost_matrix_array) {
      PyErr_SetString(PyExc_ValueError, "\"cost_matrix\" must be a numpy array "
                                        "of float32 or float64 dtype");
      return NULL;
    }
  }
  auto ndims = PyArray_NDIM(cost_matrix_array.get());
  if (ndims != 2) {
    PyErr_SetString(PyExc_ValueError,
                    "\"cost_matrix\" must be a square 2D numpy array");
    return NULL;
  }
  auto dims = PyArray_DIMS(cost_matrix_array.get());
  if (dims[0] != dims[1]) {
    PyErr_SetString(PyExc_ValueError,
                    "\"cost_matrix\" must be a square 2D numpy array");
    return NULL;
  }
  // Validate BEFORE narrowing npy_intp -> int: the original narrowed first,
  // so a dimension above INT_MAX could truncate into a "valid" value.
  if (dims[0] <= 0 || dims[0] > INT_MAX) {
    PyErr_SetString(PyExc_ValueError,
                    "\"cost_matrix\"'s shape is invalid or too large");
    return NULL;
  }
  int dim = static_cast<int>(dims[0]);
  auto cost_matrix = PyArray_DATA(cost_matrix_array.get());
  npy_intp ret_dims[] = {dim, 0};
  pyarray row_ind_array(PyArray_SimpleNew(1, ret_dims, NPY_INT));
  pyarray col_ind_array(PyArray_SimpleNew(1, ret_dims, NPY_INT));
  pyarray u_array(PyArray_SimpleNew(
      1, ret_dims, float32? NPY_FLOAT32 : NPY_FLOAT64));
  pyarray v_array(PyArray_SimpleNew(
      1, ret_dims, float32? NPY_FLOAT32 : NPY_FLOAT64));
  // PyArray_SimpleNew returns NULL with an exception set on allocation
  // failure; the original dereferenced the results unchecked.
  if (!row_ind_array || !col_ind_array || !u_array || !v_array) {
    return NULL;
  }
  auto row_ind = reinterpret_cast<int*>(PyArray_DATA(row_ind_array.get()));
  auto col_ind = reinterpret_cast<int*>(PyArray_DATA(col_ind_array.get()));
  auto u = PyArray_DATA(u_array.get());
  auto v = PyArray_DATA(v_array.get());
  double lapcost;
  // Dispatch on the element type resolved above; the GIL is released inside.
  if (float32) {
    lapcost = call_lap<float>(dim, cost_matrix, verbose, disable_avx,
                              row_ind, col_ind, u, v);
  } else {
    lapcost = call_lap<double>(dim, cost_matrix, verbose, disable_avx,
                               row_ind, col_ind, u, v);
  }
  // "O" increments each array's refcount, so the pyarray owners correctly
  // release their own references when this function returns.
  return Py_BuildValue("(OO(dOO))",
                       row_ind_array.get(), col_ind_array.get(), lapcost,
                       u_array.get(), v_array.get());
}