Skip to content

Commit

Permalink
Reduce memory consumption
Browse files Browse the repository at this point in the history
  • Loading branch information
vmarkovtsev committed Jun 11, 2024
1 parent 162474a commit bfd47ba
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 16 deletions.
14 changes: 7 additions & 7 deletions lap.h
Original file line number Diff line number Diff line change
Expand Up @@ -228,15 +228,14 @@ find_umins(
/// @param u out dual variables, row reduction numbers / size dim
/// @param v out dual variables, column reduction numbers / size dim
/// @return achieved minimum assignment cost
template <bool avx2, typename idx, typename cost>
cost lap(int dim, const cost *restrict assign_cost, bool verbose,
template <bool avx2, bool verbose, typename idx, typename cost>
cost lap(int dim, const cost *restrict assign_cost,
idx *restrict rowsol, idx *restrict colsol,
cost *restrict u, cost *restrict v) {
auto free = std::unique_ptr<idx[]>(new idx[dim]); // list of unassigned rows.
auto collist = std::unique_ptr<idx[]>(new idx[dim]); // list of columns to be scanned in various ways.
auto matches = std::unique_ptr<idx[]>(new idx[dim]); // counts how many times a row could be assigned.
auto d = std::unique_ptr<cost[]>(new cost[dim]); // 'cost-distance' in augmenting path calculation.
auto pred = std::unique_ptr<idx[]>(new idx[dim]); // row-predecessor of column in augmenting/alternating path.
auto collist = std::make_unique<idx[]>(dim); // list of columns to be scanned in various ways.
auto matches = std::make_unique<idx[]>(dim); // counts how many times a row could be assigned.
auto d = std::make_unique<cost[]>(dim); // 'cost-distance' in augmenting path calculation.
auto pred = std::make_unique<idx[]>(dim); // row-predecessor of column in augmenting/alternating path.

// init how many times a row will be assigned in the column reduction.
#if _OPENMP >= 201307
Expand Down Expand Up @@ -273,6 +272,7 @@ cost lap(int dim, const cost *restrict assign_cost, bool verbose,
}

// REDUCTION TRANSFER
auto free = matches.get(); // list of unassigned rows.
idx numfree = 0;
for (idx i = 0; i < dim; i++) {
const cost *local_cost = &assign_cost[i * dim];
Expand Down
28 changes: 19 additions & 9 deletions python.cc
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,8 @@ using pyobj = _pyobj<PyObject>;
using pyarray = _pyobj<PyArrayObject>;

template <typename F>
static always_inline double call_lap(int dim, const void *restrict cost_matrix, bool verbose,
static always_inline double call_lap(int dim, const void *restrict cost_matrix,
bool verbose, bool disable_avx,
int *restrict row_ind, int *restrict col_ind,
void *restrict u, void *restrict v) {
double lapcost;
Expand All @@ -76,10 +77,18 @@ static always_inline double call_lap(int dim, const void *restrict cost_matrix,
auto cost_matrix_typed = reinterpret_cast<const F*>(cost_matrix);
auto u_typed = reinterpret_cast<F*>(u);
auto v_typed = reinterpret_cast<F*>(v);
if (hasAVX2) {
lapcost = lap<true>(dim, cost_matrix_typed, verbose, row_ind, col_ind, u_typed, v_typed);
if (hasAVX2 && !disable_avx) {
if (verbose) {
lapcost = lap<true, true>(dim, cost_matrix_typed, row_ind, col_ind, u_typed, v_typed);
} else {
lapcost = lap<true, false>(dim, cost_matrix_typed, row_ind, col_ind, u_typed, v_typed);
}
} else {
lapcost = lap<false>(dim, cost_matrix_typed, verbose, row_ind, col_ind, u_typed, v_typed);
if (verbose) {
lapcost = lap<false, true>(dim, cost_matrix_typed, row_ind, col_ind, u_typed, v_typed);
} else {
lapcost = lap<false, false>(dim, cost_matrix_typed, row_ind, col_ind, u_typed, v_typed);
}
}
Py_END_ALLOW_THREADS
return lapcost;
Expand All @@ -88,12 +97,13 @@ static always_inline double call_lap(int dim, const void *restrict cost_matrix,
static PyObject *py_lapjv(PyObject *self, PyObject *args, PyObject *kwargs) {
PyObject *cost_matrix_obj;
int verbose = 0;
int disable_avx = 0;
int force_doubles = 0;
static const char *kwlist[] = {
"cost_matrix", "verbose", "force_doubles", NULL};
"cost_matrix", "verbose", "disable_avx", "force_doubles", NULL};
if (!PyArg_ParseTupleAndKeywords(
args, kwargs, "O|pb", const_cast<char**>(kwlist),
&cost_matrix_obj, &verbose, &force_doubles)) {
args, kwargs, "O|pbb", const_cast<char**>(kwlist),
&cost_matrix_obj, &verbose, &disable_avx, &force_doubles)) {
return NULL;
}
pyarray cost_matrix_array;
Expand Down Expand Up @@ -144,9 +154,9 @@ static PyObject *py_lapjv(PyObject *self, PyObject *args, PyObject *kwargs) {
auto u = PyArray_DATA(u_array.get());
auto v = PyArray_DATA(v_array.get());
if (float32) {
lapcost = call_lap<float>(dim, cost_matrix, verbose, row_ind, col_ind, u, v);
lapcost = call_lap<float>(dim, cost_matrix, verbose, disable_avx, row_ind, col_ind, u, v);
} else {
lapcost = call_lap<double>(dim, cost_matrix, verbose, row_ind, col_ind, u, v);
lapcost = call_lap<double>(dim, cost_matrix, verbose, disable_avx, row_ind, col_ind, u, v);
}
return Py_BuildValue("(OO(dOO))",
row_ind_array.get(), col_ind_array.get(), lapcost,
Expand Down

0 comments on commit bfd47ba

Please sign in to comment.