From bb2f6bd6d33ae6530b487289d94ba4d056ddd79c Mon Sep 17 00:00:00 2001 From: KHLee Date: Sun, 12 Jan 2025 01:17:32 +0800 Subject: [PATCH 01/19] Add basic in-place qsort to SimpleArray --- cpp/modmesh/buffer/Algorithm.hpp | 61 +++++++++++++++++++ cpp/modmesh/buffer/CMakeLists.txt | 1 + cpp/modmesh/buffer/SimpleArray.hpp | 9 +++ cpp/modmesh/buffer/pymod/wrap_SimpleArray.cpp | 4 ++ 4 files changed, 75 insertions(+) create mode 100644 cpp/modmesh/buffer/Algorithm.hpp diff --git a/cpp/modmesh/buffer/Algorithm.hpp b/cpp/modmesh/buffer/Algorithm.hpp new file mode 100644 index 00000000..814a0c7e --- /dev/null +++ b/cpp/modmesh/buffer/Algorithm.hpp @@ -0,0 +1,61 @@ +#pragma once + +namespace modmesh { + +namespace detail { + +template +static void swap(T &a, T &b){ + if (a == b) { + return; + } + T tmp = a; + a = b; + b = tmp; +} + +template +static int compare(void *a, void *b){ + static_assert(std::is_integral_v || std::is_floating_point_v, + "T must be integral or floating-point type"); + + if (a == nullptr || b == nullptr) { + throw std::invalid_argument(Formatter() << "Null pointer shouldn't be sent into compare function"); + } + if (a == b) { + return 0; + } + return *static_cast(a) - *static_cast(b); +} + +template +void qsort(T *begin, T *end, int (*cmp)(void *, void *) = compare) { + ssize_t N = end - begin; + if (N < 2) { + return; + } + + T *end_pos = end - 1; + T *cur = begin + 1; + T *pivot = begin; + + while (cur <= end_pos) { + if (cmp(cur, pivot) < 0) { + cur++; + } else { + swap(*cur, *end_pos); + end_pos--; + } + } + swap(*pivot, *end_pos); + pivot = end_pos; + + qsort(begin, pivot, cmp); + qsort(pivot + 1, end, cmp); +} + +} /* end namespace detail */ + +} /* end namespace modmesh */ + +// vim: set ff=unix fenc=utf8 et sw=4 ts=4 sts=4: diff --git a/cpp/modmesh/buffer/CMakeLists.txt b/cpp/modmesh/buffer/CMakeLists.txt index ba88ca0b..66238c4c 100644 --- a/cpp/modmesh/buffer/CMakeLists.txt +++ b/cpp/modmesh/buffer/CMakeLists.txt @@ -11,6 +11,7 @@ set(MODMESH_BUFFER_HEADERS ${CMAKE_CURRENT_SOURCE_DIR}/BufferExpander.hpp ${CMAKE_CURRENT_SOURCE_DIR}/SimpleArray.hpp ${CMAKE_CURRENT_SOURCE_DIR}/SimpleCollector.hpp + ${CMAKE_CURRENT_SOURCE_DIR}/Algorithm.hpp CACHE FILEPATH "" FORCE) set(MODMESH_BUFFER_SOURCES diff --git a/cpp/modmesh/buffer/SimpleArray.hpp b/cpp/modmesh/buffer/SimpleArray.hpp index 7e398b05..546398a2 100644 --- a/cpp/modmesh/buffer/SimpleArray.hpp +++ b/cpp/modmesh/buffer/SimpleArray.hpp @@ -29,6 +29,7 @@ */ #include +#include #include #include @@ -585,6 +586,14 @@ class SimpleArray } } + void sort(void) + { + if (ndim() != 1){ + throw std::runtime_error("SimpleArray: Sorting is only supported in 1D array."); + } + detail::qsort(begin(), end()); + } + template value_type const & operator()(Args... args) const { return *vptr(args...); } template diff --git a/cpp/modmesh/buffer/pymod/wrap_SimpleArray.cpp b/cpp/modmesh/buffer/pymod/wrap_SimpleArray.cpp index 29919846..b5902e26 100644 --- a/cpp/modmesh/buffer/pymod/wrap_SimpleArray.cpp +++ b/cpp/modmesh/buffer/pymod/wrap_SimpleArray.cpp @@ -147,6 +147,10 @@ class MODMESH_PYTHON_WRAPPER_VISIBILITY WrapSimpleArray "reshape", [](wrapped_type const & self, py::object const & shape) { return self.reshape(make_shape(shape)); }) + .def( + "sort", + [](wrapped_type & self) + { self.sort(); }) .def_property_readonly("has_ghost", &wrapped_type::has_ghost) .def_property("nghost", &wrapped_type::nghost, &wrapped_type::set_nghost) .def_property_readonly("nbody", &wrapped_type::nbody) From a381336105a5f8ce4d3d19405508167943bc3bce Mon Sep 17 00:00:00 2001 From: KHLee Date: Tue, 21 Jan 2025 22:56:59 +0800 Subject: [PATCH 02/19] Replace self-implemented qsort by std::sort --- cpp/modmesh/buffer/Algorithm.hpp | 61 ------------------------------ cpp/modmesh/buffer/CMakeLists.txt | 1 - cpp/modmesh/buffer/SimpleArray.hpp | 5 ++- 3 files changed, 3 insertions(+), 64 deletions(-) delete mode 100644 cpp/modmesh/buffer/Algorithm.hpp diff --git a/cpp/modmesh/buffer/Algorithm.hpp b/cpp/modmesh/buffer/Algorithm.hpp deleted file mode 100644 index 814a0c7e..00000000 --- a/cpp/modmesh/buffer/Algorithm.hpp +++ /dev/null @@ -1,61 +0,0 @@ -#pragma once - -namespace modmesh { - -namespace detail { - -template -static void swap(T &a, T &b){ - if (a == b) { - return; - } - T tmp = a; - a = b; - b = tmp; -} - -template -static int compare(void *a, void *b){ - static_assert(std::is_integral_v || std::is_floating_point_v, - "T must be integral or floating-point type"); - - if (a == nullptr || b == nullptr) { - throw std::invalid_argument(Formatter() << "Null pointer shouldn't be sent into compare function"); - } - if (a == b) { - return 0; - } - return *static_cast(a) - *static_cast(b); -} - -template -void qsort(T *begin, T *end, int (*cmp)(void *, void *) = compare) { - ssize_t N = end - begin; - if (N < 2) { - return; - } - - T *end_pos = end - 1; - T *cur = begin + 1; - T *pivot = begin; - - while (cur <= end_pos) { - if (cmp(cur, pivot) < 0) { - cur++; - } else { - swap(*cur, *end_pos); - end_pos--; - } - } - swap(*pivot, *end_pos); - pivot = end_pos; - - qsort(begin, pivot, cmp); - qsort(pivot + 1, end, cmp); -} - -} /* end namespace detail */ - -} /* end namespace modmesh */ - -// vim: set ff=unix fenc=utf8 et sw=4 ts=4 sts=4: diff --git a/cpp/modmesh/buffer/CMakeLists.txt b/cpp/modmesh/buffer/CMakeLists.txt index 66238c4c..ba88ca0b 100644 --- a/cpp/modmesh/buffer/CMakeLists.txt +++ b/cpp/modmesh/buffer/CMakeLists.txt @@ -11,7 +11,6 @@ set(MODMESH_BUFFER_HEADERS ${CMAKE_CURRENT_SOURCE_DIR}/BufferExpander.hpp ${CMAKE_CURRENT_SOURCE_DIR}/SimpleArray.hpp ${CMAKE_CURRENT_SOURCE_DIR}/SimpleCollector.hpp - ${CMAKE_CURRENT_SOURCE_DIR}/Algorithm.hpp CACHE FILEPATH "" FORCE) set(MODMESH_BUFFER_SOURCES diff --git a/cpp/modmesh/buffer/SimpleArray.hpp b/cpp/modmesh/buffer/SimpleArray.hpp index 546398a2..66677f36 100644 --- a/cpp/modmesh/buffer/SimpleArray.hpp +++ b/cpp/modmesh/buffer/SimpleArray.hpp @@ -29,12 +29,12 @@ */ #include -#include #include #include #include #include +#include #if defined(_MSC_VER) #include @@ -591,7 +591,8 @@ class SimpleArray if (ndim() != 1){ throw std::runtime_error("SimpleArray: Sorting is only supported in 1D array."); } - detail::qsort(begin(), end()); + + std::sort(begin(), end()); } template From 94d78d6b404f71a44bd87aecd6bc4fca6e80217e Mon Sep 17 00:00:00 2001 From: KHLee Date: Tue, 21 Jan 2025 22:58:05 +0800 Subject: [PATCH 03/19] Implement argsort in c++ --- cpp/modmesh/buffer/SimpleArray.hpp | 71 ++++++++++++++++++++++++++++++ 1 file changed, 71 insertions(+) diff --git a/cpp/modmesh/buffer/SimpleArray.hpp b/cpp/modmesh/buffer/SimpleArray.hpp index 66677f36..29fd701c 100644 --- a/cpp/modmesh/buffer/SimpleArray.hpp +++ b/cpp/modmesh/buffer/SimpleArray.hpp @@ -595,6 +595,77 @@ class SimpleArray std::sort(begin(), end()); } + SimpleArray argsort(void) + { + if (ndim() != 1){ + throw std::runtime_error("SimpleArray: Sorting is only supported in 1D array."); + } + + SimpleArray ret(shape()); + + { // Return array initialization + size_t cnt = 0; + std::for_each(ret.begin(), ret.end(), [&cnt](size_t &v){v = cnt++;}); + } + + value_type const *buf = body(); + auto cmp = [buf](size_t a, size_t b) { + return buf[a] < buf[b]; + }; + std::sort(ret.begin(), ret.end(), cmp); + return ret; + } + + void apply_argsort(SimpleArray const &sorted_args) + { + if (ndim() != 1 || sorted_args.ndim() != 1){ + throw std::runtime_error("SimpleArray: Sorting is only supported in 1D array."); + } + if (shape()[0] != sorted_args.shape()[0]){ + throw std::runtime_error("SimpleArray: argsort only support same shape"); + } + if (shape()[0] < 2) { + return; + } + + std::vector applied_arg(shape()[0], false); + + auto all = [](std::vector &vec) { + for (auto i : vec) { + if (i == false) return false; + } + return true; + }; + + auto next = [](std::vector &vec, ssize_t last) { + for (ssize_t i = last; i < static_cast(vec.size()); i++) { + if (vec[i] == false) return i; + } + return static_cast(-1); + }; + + ssize_t idx = 0; + while (!all(applied_arg)) { + idx = next(applied_arg, idx); + if (idx == -1) break; + + value_type val = at(idx); + + ssize_t dst_idx = idx; + ssize_t src_idx = sorted_args[dst_idx]; + + while(src_idx != idx) { + at(dst_idx) = at(src_idx); + applied_arg.at(dst_idx) = true; + dst_idx = src_idx; + src_idx = sorted_args[dst_idx]; + } + + at(dst_idx) = val; + applied_arg.at(dst_idx) = true; + } + } + template value_type const & operator()(Args... args) const { return *vptr(args...); } template From 00bc2b630dd27eae352b99bb69786516d2c3f37b Mon Sep 17 00:00:00 2001 From: KHLee Date: Wed, 22 Jan 2025 21:05:56 +0800 Subject: [PATCH 04/19] Make return array of argsort into uint64_t and wrap argsort In order to make the datatype of return value of SimpleArray::argsort compatible to the wrapped python datatypes, the return value is changed to SimpleArray. Function argsort is also wrapped in this commit. --- cpp/modmesh/buffer/SimpleArray.hpp | 12 ++++++------ cpp/modmesh/buffer/pymod/wrap_SimpleArray.cpp | 4 ++++ 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/cpp/modmesh/buffer/SimpleArray.hpp b/cpp/modmesh/buffer/SimpleArray.hpp index 29fd701c..953e8345 100644 --- a/cpp/modmesh/buffer/SimpleArray.hpp +++ b/cpp/modmesh/buffer/SimpleArray.hpp @@ -595,28 +595,28 @@ class SimpleArray std::sort(begin(), end()); } - SimpleArray argsort(void) + SimpleArray argsort(void) { if (ndim() != 1){ throw std::runtime_error("SimpleArray: Sorting is only supported in 1D array."); } - SimpleArray ret(shape()); + SimpleArray ret(shape()); { // Return array initialization - size_t cnt = 0; - std::for_each(ret.begin(), ret.end(), [&cnt](size_t &v){v = cnt++;}); + uint64_t cnt = 0; + std::for_each(ret.begin(), ret.end(), [&cnt](uint64_t &v){v = cnt++;}); } value_type const *buf = body(); - auto cmp = [buf](size_t a, size_t b) { + auto cmp = [buf](uint64_t a, uint64_t b) { return buf[a] < buf[b]; }; std::sort(ret.begin(), ret.end(), cmp); return ret; } - void apply_argsort(SimpleArray const &sorted_args) + void apply_argsort(SimpleArray const &sorted_args) { if (ndim() != 1 || sorted_args.ndim() != 1){ throw std::runtime_error("SimpleArray: Sorting is only supported in 1D array."); diff --git a/cpp/modmesh/buffer/pymod/wrap_SimpleArray.cpp b/cpp/modmesh/buffer/pymod/wrap_SimpleArray.cpp index b5902e26..d0655a96 100644 --- a/cpp/modmesh/buffer/pymod/wrap_SimpleArray.cpp +++ b/cpp/modmesh/buffer/pymod/wrap_SimpleArray.cpp @@ -151,6 +151,10 @@ class MODMESH_PYTHON_WRAPPER_VISIBILITY WrapSimpleArray "sort", [](wrapped_type & self) { self.sort(); }) + .def( + "argsort", + [](wrapped_type & self) + { return pybind11::cast(self.argsort()); }) .def_property_readonly("has_ghost", &wrapped_type::has_ghost) .def_property("nghost", &wrapped_type::nghost, &wrapped_type::set_nghost) .def_property_readonly("nbody", &wrapped_type::nbody) From 3ac729a31e823e0231e9eedb90b3d213cb63696f Mon Sep 17 00:00:00 2001 From: KHLee Date: Wed, 22 Jan 2025 21:15:40 +0800 Subject: [PATCH 05/19] Add basic test case for sorting of SimpleArray --- tests/test_buffer.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/tests/test_buffer.py b/tests/test_buffer.py index 1ac6a2db..219bf4ad 100644 --- a/tests/test_buffer.py +++ b/tests/test_buffer.py @@ -824,6 +824,18 @@ def test_SimpleArray_SimpleArrayPlex_type_switch(self): self.assertEqual( str(type(arrayplex_int32_2)), "") + def test_sort(self): + narr = np.random.randint(0, 100, 20, dtype='int32') + sarr = modmesh.SimpleArrayInt32(array=narr); + args = sarr.argsort() + for i in range(1, len(args)): + self.assertLessEqual(sarr[args[i]], sarr[args[i]]) + + sarr.sort() + for i in range(1, len(args)): + self.assertLessEqual(sarr[i - 1], sarr[i]) + + class SimpleArrayCalculatorsTC(unittest.TestCase): @@ -1068,5 +1080,7 @@ def test_construct(self): for it in range(6): ct[it] = it + 10 self.assertEqual(list(it + 10 for it in range(6)), list(ct)) + + # vim: set ff=unix fenc=utf8 et sw=4 ts=4 sts=4: From 6a7f163f12d331901b16de7923bd17d9ccb532e4 Mon Sep 17 00:00:00 2001 From: KHLee Date: Sun, 2 Feb 2025 19:33:28 +0800 Subject: [PATCH 06/19] Lint files with linter --- cpp/modmesh/buffer/SimpleArray.hpp | 56 +++++++++++++++++++----------- tests/test_buffer.py | 5 +-- 2 files changed, 37 insertions(+), 24 deletions(-) diff --git a/cpp/modmesh/buffer/SimpleArray.hpp b/cpp/modmesh/buffer/SimpleArray.hpp index 953e8345..426c41ba 100644 --- a/cpp/modmesh/buffer/SimpleArray.hpp +++ b/cpp/modmesh/buffer/SimpleArray.hpp @@ -588,7 +588,8 @@ class SimpleArray void sort(void) { - if (ndim() != 1){ + if (ndim() != 1) + { throw std::runtime_error("SimpleArray: Sorting is only supported in 1D array."); } @@ -597,64 +598,79 @@ class SimpleArray SimpleArray argsort(void) { - if (ndim() != 1){ + if (ndim() != 1) + { throw std::runtime_error("SimpleArray: Sorting is only supported in 1D array."); } SimpleArray ret(shape()); - { // Return array initialization + { // Return array initialization uint64_t cnt = 0; - std::for_each(ret.begin(), ret.end(), [&cnt](uint64_t &v){v = cnt++;}); - } + std::for_each(ret.begin(), ret.end(), [&cnt](uint64_t & v) + { v = cnt++; }); + } - value_type const *buf = body(); - auto cmp = [buf](uint64_t a, uint64_t b) { + value_type const * buf = body(); + auto cmp = [buf](uint64_t a, uint64_t b) + { return buf[a] < buf[b]; }; std::sort(ret.begin(), ret.end(), cmp); return ret; } - void apply_argsort(SimpleArray const &sorted_args) + void apply_argsort(SimpleArray const & sorted_args) { - if (ndim() != 1 || sorted_args.ndim() != 1){ + if (ndim() != 1 || sorted_args.ndim() != 1) + { throw std::runtime_error("SimpleArray: Sorting is only supported in 1D array."); } - if (shape()[0] != sorted_args.shape()[0]){ + if (shape()[0] != sorted_args.shape()[0]) + { throw std::runtime_error("SimpleArray: argsort only support same shape"); } - if (shape()[0] < 2) { + if (shape()[0] < 2) + { return; } std::vector applied_arg(shape()[0], false); - auto all = [](std::vector &vec) { - for (auto i : vec) { - if (i == false) return false; + auto all = [](std::vector & vec) + { + for (auto i : vec) + { + if (i == false) + return false; } return true; }; - auto next = [](std::vector &vec, ssize_t last) { - for (ssize_t i = last; i < static_cast(vec.size()); i++) { - if (vec[i] == false) return i; + auto next = [](std::vector & vec, ssize_t last) + { + for (ssize_t i = last; i < static_cast(vec.size()); i++) + { + if (vec[i] == false) + return i; } return static_cast(-1); }; ssize_t idx = 0; - while (!all(applied_arg)) { + while (!all(applied_arg)) + { idx = next(applied_arg, idx); - if (idx == -1) break; + if (idx == -1) + break; value_type val = at(idx); ssize_t dst_idx = idx; ssize_t src_idx = sorted_args[dst_idx]; - while(src_idx != idx) { + while (src_idx != idx) + { at(dst_idx) = at(src_idx); applied_arg.at(dst_idx) = true; dst_idx = src_idx; diff --git a/tests/test_buffer.py b/tests/test_buffer.py index 219bf4ad..d95edc74 100644 --- a/tests/test_buffer.py +++ b/tests/test_buffer.py @@ -826,7 +826,7 @@ def test_SimpleArray_SimpleArrayPlex_type_switch(self): def test_sort(self): narr = np.random.randint(0, 100, 20, dtype='int32') - sarr = modmesh.SimpleArrayInt32(array=narr); + sarr = modmesh.SimpleArrayInt32(array=narr) args = sarr.argsort() for i in range(1, len(args)): self.assertLessEqual(sarr[args[i]], sarr[args[i]]) @@ -836,7 +836,6 @@ def test_sort(self): self.assertLessEqual(sarr[i - 1], sarr[i]) - class SimpleArrayCalculatorsTC(unittest.TestCase): def test_minmaxsum(self): @@ -1080,7 +1079,5 @@ def test_construct(self): for it in range(6): ct[it] = it + 10 self.assertEqual(list(it + 10 for it in range(6)), list(ct)) - - # vim: set ff=unix fenc=utf8 et sw=4 ts=4 sts=4: From 6dd619de8fd7ad683e4f9bba5c899ccd8ce78eb2 Mon Sep 17 00:00:00 2001 From: KHLee Date: Mon, 3 Feb 2025 19:29:15 +0800 Subject: [PATCH 07/19] Move sort to SimpleArrayMixinCalculator --- cpp/modmesh/buffer/SimpleArray.hpp | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/cpp/modmesh/buffer/SimpleArray.hpp b/cpp/modmesh/buffer/SimpleArray.hpp index 426c41ba..54f7d39b 100644 --- a/cpp/modmesh/buffer/SimpleArray.hpp +++ b/cpp/modmesh/buffer/SimpleArray.hpp @@ -193,6 +193,17 @@ class SimpleArrayMixinCalculators } return ret; } + + void sort(void) + { + auto athis = static_cast(this); + if (athis->ndim() != 1) + { + throw std::runtime_error(Formatter() << "SimpleArray: sorting is only supported in 1D array. " << athis->ndim() << "D array is currently not supported"); + } + + std::sort(athis->begin(), athis->end()); + } }; /* end class SimpleArrayMixinCalculators */ } /* end namespace detail */ @@ -586,16 +597,6 @@ class SimpleArray } } - void sort(void) - { - if (ndim() != 1) - { - throw std::runtime_error("SimpleArray: Sorting is only supported in 1D array."); - } - - std::sort(begin(), end()); - } - SimpleArray argsort(void) { if (ndim() != 1) From 1f59f92874cabf7968ac3073bcf7f57698375903 Mon Sep 17 00:00:00 2001 From: KHLee Date: Mon, 3 Feb 2025 21:24:11 +0800 Subject: [PATCH 08/19] Move apply_argsort out from class declaration --- cpp/modmesh/buffer/SimpleArray.hpp | 116 ++++++++++++++--------------- 1 file changed, 54 insertions(+), 62 deletions(-) diff --git a/cpp/modmesh/buffer/SimpleArray.hpp b/cpp/modmesh/buffer/SimpleArray.hpp index 54f7d39b..b34c380c 100644 --- a/cpp/modmesh/buffer/SimpleArray.hpp +++ b/cpp/modmesh/buffer/SimpleArray.hpp @@ -199,7 +199,8 @@ class SimpleArrayMixinCalculators auto athis = static_cast(this); if (athis->ndim() != 1) { - throw std::runtime_error(Formatter() << "SimpleArray: sorting is only supported in 1D array. " << athis->ndim() << "D array is currently not supported"); + throw std::runtime_error(Formatter() << "SimpleArray: sorting is only supported in 1D array. " + << athis->ndim() << "D array is currently not supported"); } std::sort(athis->begin(), athis->end()); @@ -621,67 +622,7 @@ class SimpleArray return ret; } - void apply_argsort(SimpleArray const & sorted_args) - { - if (ndim() != 1 || sorted_args.ndim() != 1) - { - throw std::runtime_error("SimpleArray: Sorting is only supported in 1D array."); - } - if (shape()[0] != sorted_args.shape()[0]) - { - throw std::runtime_error("SimpleArray: argsort only support same shape"); - } - if (shape()[0] < 2) - { - return; - } - - std::vector applied_arg(shape()[0], false); - - auto all = [](std::vector & vec) - { - for (auto i : vec) - { - if (i == false) - return false; - } - return true; - }; - - auto next = [](std::vector & vec, ssize_t last) - { - for (ssize_t i = last; i < static_cast(vec.size()); i++) - { - if (vec[i] == false) - return i; - } - return static_cast(-1); - }; - - ssize_t idx = 0; - while (!all(applied_arg)) - { - idx = next(applied_arg, idx); - if (idx == -1) - break; - - value_type val = at(idx); - - ssize_t dst_idx = idx; - ssize_t src_idx = sorted_args[dst_idx]; - - while (src_idx != idx) - { - at(dst_idx) = at(src_idx); - applied_arg.at(dst_idx) = true; - dst_idx = src_idx; - src_idx = sorted_args[dst_idx]; - } - - at(dst_idx) = val; - applied_arg.at(dst_idx) = true; - } - } + void apply_argsort(SimpleArray const & sorted_args); template value_type const & operator()(Args... args) const { return *vptr(args...); } @@ -827,6 +768,57 @@ class SimpleArray value_type * m_body = nullptr; }; /* end class SimpleArray */ +template +void SimpleArray::apply_argsort(SimpleArray const & sorted_args) +{ + if (ndim() != 1 || sorted_args.ndim() != 1) + { + throw std::runtime_error("SimpleArray: Sorting is only supported in 1D array."); + } + if (shape()[0] != sorted_args.shape()[0]) + { + throw std::runtime_error("SimpleArray: argsort only support same shape"); + } + if (shape()[0] < 2) + { + return; + } + + std::vector applied_arg(shape()[0], false); + + auto next = [](std::vector & vec, ssize_t last) + { + for (ssize_t i = last; i < static_cast(vec.size()); ++i) + { + if (vec.at(i) == false) + { + return i; + } + } + return static_cast(-1); + }; + + ssize_t idx = 0; + while ((idx = next(applied_arg, idx)) != -1) + { + value_type val = at(idx); + + ssize_t dst_idx = idx; + ssize_t src_idx = sorted_args[dst_idx]; + + while (src_idx != idx) + { + at(dst_idx) = at(src_idx); + applied_arg.at(dst_idx) = true; + dst_idx = src_idx; + src_idx = sorted_args[dst_idx]; + } + + at(dst_idx) = val; + applied_arg.at(dst_idx) = true; + } +} + template using is_simple_array = std::is_same< std::remove_reference_t, From 1998843e2899809fa52d2154edc46d557b65751c Mon Sep 17 00:00:00 2001 From: KHLee Date: Mon, 3 Feb 2025 21:53:26 +0800 Subject: [PATCH 09/19] Rename apply_argsort to take_along_axis to maintain consistency with numpy --- cpp/modmesh/buffer/SimpleArray.hpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/cpp/modmesh/buffer/SimpleArray.hpp b/cpp/modmesh/buffer/SimpleArray.hpp index b34c380c..de7616b8 100644 --- a/cpp/modmesh/buffer/SimpleArray.hpp +++ b/cpp/modmesh/buffer/SimpleArray.hpp @@ -199,8 +199,8 @@ class SimpleArrayMixinCalculators auto athis = static_cast(this); if (athis->ndim() != 1) { - throw std::runtime_error(Formatter() << "SimpleArray: sorting is only supported in 1D array. " - << athis->ndim() << "D array is currently not supported"); + throw std::runtime_error(Formatter() << "SimpleArray: sorting is only supported in 1D array. " + << athis->ndim() << "D array is currently not supported"); } std::sort(athis->begin(), athis->end()); @@ -622,7 +622,7 @@ class SimpleArray return ret; } - void apply_argsort(SimpleArray const & sorted_args); + void take_along_axis(SimpleArray const & sorted_args); template value_type const & operator()(Args... args) const { return *vptr(args...); } @@ -769,7 +769,7 @@ class SimpleArray }; /* end class SimpleArray */ template -void SimpleArray::apply_argsort(SimpleArray const & sorted_args) +void SimpleArray::take_along_axis(SimpleArray const & sorted_args) { if (ndim() != 1 || sorted_args.ndim() != 1) { From 45b2d4d5690522264cf84ef37aef6b9fe411db60 Mon Sep 17 00:00:00 2001 From: KHLee Date: Mon, 3 Feb 2025 22:23:11 +0800 Subject: [PATCH 10/19] Rename function parameters and refine error messages --- cpp/modmesh/buffer/SimpleArray.hpp | 26 +++++++++++++++++--------- 1 file changed, 17 insertions(+), 9 deletions(-) diff --git a/cpp/modmesh/buffer/SimpleArray.hpp b/cpp/modmesh/buffer/SimpleArray.hpp index de7616b8..41cd8f3a 100644 --- a/cpp/modmesh/buffer/SimpleArray.hpp +++ b/cpp/modmesh/buffer/SimpleArray.hpp @@ -602,7 +602,8 @@ class SimpleArray { if (ndim() != 1) { - throw std::runtime_error("SimpleArray: Sorting is only supported in 1D array."); + throw std::runtime_error(Formatter() << "SimpleArray: sorting is only supported in 1D array. " + << ndim() << "D array is currently not supported"); } SimpleArray ret(shape()); @@ -622,7 +623,7 @@ class SimpleArray return ret; } - void take_along_axis(SimpleArray const & sorted_args); + void take_along_axis(SimpleArray const & indices); template value_type const & operator()(Args... args) const { return *vptr(args...); } @@ -769,15 +770,22 @@ class SimpleArray }; /* end class SimpleArray */ template -void SimpleArray::take_along_axis(SimpleArray const & sorted_args) +void SimpleArray::take_along_axis(SimpleArray const & indices) { - if (ndim() != 1 || sorted_args.ndim() != 1) + if (indices.ndim() != 1) + { + throw std::runtime_error(Formatter() << "SimpleArray: sorting is only supported in 1D array. " + << indices.ndim() << "D indices is not supported."); + } + if (ndim() != 1) { - throw std::runtime_error("SimpleArray: Sorting is only supported in 1D array."); + throw std::runtime_error(Formatter() << "SimpleArray: sorting is only supported in 1D array. " + << ndim() << "D array is not supported to be sorted."); } - if (shape()[0] != sorted_args.shape()[0]) + if (shape()[0] != indices.shape()[0]) { - throw std::runtime_error("SimpleArray: argsort only support same shape"); + throw std::runtime_error(Formatter() << "SimpleArray: take_along_axis only support same shape of indices and array." + << "Array size " << shape()[0] << " != indices shape " << indices.shape()[0]); } if (shape()[0] < 2) { @@ -804,14 +812,14 @@ void SimpleArray::take_along_axis(SimpleArray const & sorted_args) value_type val = at(idx); ssize_t dst_idx = idx; - ssize_t src_idx = sorted_args[dst_idx]; + ssize_t src_idx = indices[dst_idx]; while (src_idx != idx) { at(dst_idx) = at(src_idx); applied_arg.at(dst_idx) = true; dst_idx = src_idx; - src_idx = sorted_args[dst_idx]; + src_idx = indices[dst_idx]; } at(dst_idx) = val; From 0e803b53e5ee03ddba354a7d6d40d64f60cef5d3 Mon Sep 17 00:00:00 2001 From: khlee529 Date: Tue, 4 Feb 2025 22:05:01 +0800 Subject: [PATCH 11/19] Refine sort and argsort --- cpp/modmesh/buffer/SimpleArray.hpp | 130 ++++++++---------- cpp/modmesh/buffer/pymod/wrap_SimpleArray.cpp | 5 +- 2 files changed, 60 insertions(+), 75 deletions(-) diff --git a/cpp/modmesh/buffer/SimpleArray.hpp b/cpp/modmesh/buffer/SimpleArray.hpp index 41cd8f3a..b3b3b452 100644 --- a/cpp/modmesh/buffer/SimpleArray.hpp +++ b/cpp/modmesh/buffer/SimpleArray.hpp @@ -44,6 +44,9 @@ typedef SSIZE_T ssize_t; namespace modmesh { +template +class SimpleArray; // forward declaration + namespace detail { @@ -194,17 +197,9 @@ class SimpleArrayMixinCalculators return ret; } - void sort(void) - { - auto athis = static_cast(this); - if (athis->ndim() != 1) - { - throw std::runtime_error(Formatter() << "SimpleArray: sorting is only supported in 1D array. " - << athis->ndim() << "D array is currently not supported"); - } + void sort(void); + SimpleArray argsort(void); - std::sort(athis->begin(), athis->end()); - } }; /* end class SimpleArrayMixinCalculators */ } /* end namespace detail */ @@ -598,32 +593,8 @@ class SimpleArray } } - SimpleArray argsort(void) - { - if (ndim() != 1) - { - throw std::runtime_error(Formatter() << "SimpleArray: sorting is only supported in 1D array. " - << ndim() << "D array is currently not supported"); - } - - SimpleArray ret(shape()); - - { // Return array initialization - uint64_t cnt = 0; - std::for_each(ret.begin(), ret.end(), [&cnt](uint64_t & v) - { v = cnt++; }); - } - - value_type const * buf = body(); - auto cmp = [buf](uint64_t a, uint64_t b) - { - return buf[a] < buf[b]; - }; - std::sort(ret.begin(), ret.end(), cmp); - return ret; - } - - void take_along_axis(SimpleArray const & indices); + template + SimpleArray take_along_axis(SimpleArray const & indices); template value_type const & operator()(Args... args) const { return *vptr(args...); } @@ -770,61 +741,78 @@ class SimpleArray }; /* end class SimpleArray */ template -void SimpleArray::take_along_axis(SimpleArray const & indices) +template +SimpleArray SimpleArray::take_along_axis(SimpleArray const & indices) { + static_assert(std::is_integral_v, "I must be integral type"); if (indices.ndim() != 1) { - throw std::runtime_error(Formatter() << "SimpleArray: sorting is only supported in 1D array. " - << indices.ndim() << "D indices is not supported."); + throw std::runtime_error(Formatter() << "SimpleArray: take_along_axis() supports only " + "in 1D array but the index array is " + << indices.ndim() << " dimension"); } if (ndim() != 1) { - throw std::runtime_error(Formatter() << "SimpleArray: sorting is only supported in 1D array. " - << ndim() << "D array is not supported to be sorted."); + throw std::runtime_error(Formatter() << "SimpleArray: take_along_axis() supports only " + "in 1D array but the array is " + << ndim() << " dimension"); } if (shape()[0] != indices.shape()[0]) { - throw std::runtime_error(Formatter() << "SimpleArray: take_along_axis only support same shape of indices and array." - << "Array size " << shape()[0] << " != indices shape " << indices.shape()[0]); + throw std::runtime_error(Formatter() << "SimpleArray: take_along_axis() supports same " + "shape of indices and array. Array size " + << shape()[0] << " != indices array size " << indices.shape()[0]); } - if (shape()[0] < 2) + + SimpleArray ret(*this); + for (size_t i = 0; i < shape()[0]; ++i) { - return; + ret.at(i) = at(static_cast(indices[i])); } - std::vector applied_arg(shape()[0], false); + return ret; +} - auto next = [](std::vector & vec, ssize_t last) +template +void detail::SimpleArrayMixinCalculators::sort(void) +{ + auto athis = static_cast(this); + if (athis->ndim() != 1) { - for (ssize_t i = last; i < static_cast(vec.size()); ++i) - { - if (vec.at(i) == false) - { - return i; - } - } - return static_cast(-1); - }; + throw std::runtime_error(Formatter() << "SimpleArray: sort() supports only in 1D array " + " but the array is " + << athis->ndim() << " dimension"); + } - ssize_t idx = 0; - while ((idx = next(applied_arg, idx)) != -1) - { - value_type val = at(idx); + std::sort(athis->begin(), athis->end()); +} - ssize_t dst_idx = idx; - ssize_t src_idx = indices[dst_idx]; +template +SimpleArray detail::SimpleArrayMixinCalculators::argsort(void) +{ + auto athis = static_cast(this); + if (athis->ndim() != 1) + { + throw std::runtime_error(Formatter() << "SimpleArray: argsort() supports only in 1D array " + " but the array is " + << athis->ndim() << " dimension"); + } - while (src_idx != idx) - { - at(dst_idx) = at(src_idx); - applied_arg.at(dst_idx) = true; - dst_idx = src_idx; - src_idx = indices[dst_idx]; - } + SimpleArray ret(athis->shape()); - at(dst_idx) = val; - applied_arg.at(dst_idx) = true; + { // Return array initialization + uint64_t cnt = 0; + std::for_each(ret.begin(), ret.end(), [&cnt](uint64_t & v) + { v = cnt++; }); } + + value_type const * buf = athis->body(); + auto cmp = [buf](uint64_t a, uint64_t b) + { + return buf[a] < buf[b]; + }; + std::sort(ret.begin(), ret.end(), cmp); + return ret; } template diff --git a/cpp/modmesh/buffer/pymod/wrap_SimpleArray.cpp b/cpp/modmesh/buffer/pymod/wrap_SimpleArray.cpp index d0655a96..d5b8e8a0 100644 --- a/cpp/modmesh/buffer/pymod/wrap_SimpleArray.cpp +++ b/cpp/modmesh/buffer/pymod/wrap_SimpleArray.cpp @@ -147,10 +147,6 @@ class MODMESH_PYTHON_WRAPPER_VISIBILITY WrapSimpleArray "reshape", [](wrapped_type const & self, py::object const & shape) { return self.reshape(make_shape(shape)); }) - .def( - "sort", - [](wrapped_type & self) - { self.sort(); }) .def( "argsort", [](wrapped_type & self) @@ -193,6 +189,7 @@ class MODMESH_PYTHON_WRAPPER_VISIBILITY WrapSimpleArray .def("max", &wrapped_type::max) .def("sum", &wrapped_type::sum) .def("abs", &wrapped_type::abs) + .def("sort", &wrapped_type::sort) // ; From 111175e11468a0cb8423a4cc01f79eeeac2829ee Mon Sep 17 00:00:00 2001 From: khlee529 Date: Tue, 4 Feb 2025 22:33:03 +0800 Subject: [PATCH 12/19] Wrap take_along_axis --- cpp/modmesh/buffer/pymod/wrap_SimpleArray.cpp | 3 +++ tests/test_buffer.py | 6 +++++- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/cpp/modmesh/buffer/pymod/wrap_SimpleArray.cpp b/cpp/modmesh/buffer/pymod/wrap_SimpleArray.cpp index d5b8e8a0..0fef6396 100644 --- a/cpp/modmesh/buffer/pymod/wrap_SimpleArray.cpp +++ b/cpp/modmesh/buffer/pymod/wrap_SimpleArray.cpp @@ -151,6 +151,9 @@ class MODMESH_PYTHON_WRAPPER_VISIBILITY WrapSimpleArray "argsort", [](wrapped_type & self) { return pybind11::cast(self.argsort()); }) + .def("take_along_axis", + [](wrapped_type & self, py::object const & indices) + { return pybind11::cast(self.take_along_axis(indices.cast())); }) .def_property_readonly("has_ghost", &wrapped_type::has_ghost) .def_property("nghost", &wrapped_type::nghost, &wrapped_type::set_nghost) .def_property_readonly("nbody", &wrapped_type::nbody) diff --git a/tests/test_buffer.py b/tests/test_buffer.py index d95edc74..e69ffcec 100644 --- a/tests/test_buffer.py +++ b/tests/test_buffer.py @@ -831,8 +831,12 @@ def test_sort(self): for i in range(1, len(args)): self.assertLessEqual(sarr[args[i]], sarr[args[i]]) + sorted_arr = sarr.take_along_axis(args) + for i in range(1, len(sorted_arr)): + self.assertLessEqual(sorted_arr[i - 1], sorted_arr[i]) + sarr.sort() - for i in range(1, len(args)): + for i in range(1, len(sarr)): self.assertLessEqual(sarr[i - 1], sarr[i]) From fd45b60b5ad4f3280e4971c4176b747e4aeb7238 Mon Sep 17 00:00:00 2001 From: khlee529 Date: Wed, 5 Feb 2025 19:28:30 +0800 Subject: [PATCH 13/19] Refactor sorting function --- cpp/modmesh/buffer/SimpleArray.hpp | 76 +++++++++++++++--------------- 1 file changed, 38 insertions(+), 38 deletions(-) diff --git a/cpp/modmesh/buffer/SimpleArray.hpp b/cpp/modmesh/buffer/SimpleArray.hpp index b3b3b452..94ef5453 100644 --- a/cpp/modmesh/buffer/SimpleArray.hpp +++ b/cpp/modmesh/buffer/SimpleArray.hpp @@ -196,9 +196,24 @@ class SimpleArrayMixinCalculators } return ret; } +}; /* end class SimpleArrayMixinCalculators */ + +template +class SimpleArrayMixinSorters +{ + +private: + + using internal_types = detail::SimpleArrayInternalTypes; + +public: + + using value_type = typename internal_types::value_type; void sort(void); SimpleArray argsort(void); + template + A take_along_axis(SimpleArray const & indices); }; /* end class SimpleArrayMixinCalculators */ @@ -213,6 +228,7 @@ template class SimpleArray : public detail::SimpleArrayMixinModifiers, T> , public detail::SimpleArrayMixinCalculators, T> + , public detail::SimpleArrayMixinSorters, T> { private: @@ -593,9 +609,6 @@ class SimpleArray } } - template - SimpleArray take_along_axis(SimpleArray const & indices); - template value_type const & operator()(Args... args) const { return *vptr(args...); } template @@ -740,41 +753,8 @@ class SimpleArray value_type * m_body = nullptr; }; /* end class SimpleArray */ -template -template -SimpleArray SimpleArray::take_along_axis(SimpleArray const & indices) -{ - static_assert(std::is_integral_v, "I must be integral type"); - if (indices.ndim() != 1) - { - throw std::runtime_error(Formatter() << "SimpleArray: take_along_axis() supports only " - "in 1D array but the index array is " - << indices.ndim() << " dimension"); - } - if (ndim() != 1) - { - throw std::runtime_error(Formatter() << "SimpleArray: take_along_axis() supports only " - "in 1D array but the array is " - << ndim() << " dimension"); - } - if (shape()[0] != indices.shape()[0]) - { - throw std::runtime_error(Formatter() << "SimpleArray: take_along_axis() supports same " - "shape of indices and array. Array size " - << shape()[0] << " != indices array size " << indices.shape()[0]); - } - - SimpleArray ret(*this); - for (size_t i = 0; i < shape()[0]; ++i) - { - ret.at(i) = at(static_cast(indices[i])); - } - - return ret; -} - template -void detail::SimpleArrayMixinCalculators::sort(void) +void detail::SimpleArrayMixinSorters::sort(void) { auto athis = static_cast(this); if (athis->ndim() != 1) @@ -788,7 +768,7 @@ void detail::SimpleArrayMixinCalculators::sort(void) } template -SimpleArray detail::SimpleArrayMixinCalculators::argsort(void) +SimpleArray detail::SimpleArrayMixinSorters::argsort(void) { auto athis = static_cast(this); if (athis->ndim() != 1) @@ -815,6 +795,26 @@ SimpleArray detail::SimpleArrayMixinCalculators::argsort(void) return ret; } +template +template +A detail::SimpleArrayMixinSorters::take_along_axis(SimpleArray const & indices) +{ + static_assert(std::is_integral_v, "I must be integral type"); + auto athis = static_cast(this); + if (athis->ndim() != 1) + { + throw std::runtime_error(Formatter() << "SimpleArray: take_along_axis() supports currently only in 1D array " + " but the array is " << athis->ndim() << " dimension"); + } + + SimpleArray ret(indices.shape()); + std::transform(indices.begin(), indices.end(), ret.begin(), + [athis](I idx) + { return athis->at(static_cast(idx)); }); + + return ret; +} + template using is_simple_array = std::is_same< std::remove_reference_t, From 8944cb2b5af7ef5a2ddb553c19dbeb8f8ccf773a Mon Sep 17 00:00:00 2001 From: khlee529 Date: Wed, 5 Feb 2025 19:41:09 +0800 Subject: [PATCH 14/19] Refine sorting test --- tests/test_buffer.py | 33 ++++++++++++++++++++------------- 1 file changed, 20 insertions(+), 13 deletions(-) diff --git a/tests/test_buffer.py b/tests/test_buffer.py index e69ffcec..e7bf1b5c 100644 --- a/tests/test_buffer.py +++ b/tests/test_buffer.py @@ -825,19 +825,26 @@ def test_SimpleArray_SimpleArrayPlex_type_switch(self): str(type(arrayplex_int32_2)), "") def test_sort(self): - narr = np.random.randint(0, 100, 20, dtype='int32') - sarr = modmesh.SimpleArrayInt32(array=narr) - args = sarr.argsort() - for i in range(1, len(args)): - self.assertLessEqual(sarr[args[i]], sarr[args[i]]) - - sorted_arr = sarr.take_along_axis(args) - for i in range(1, len(sorted_arr)): - self.assertLessEqual(sorted_arr[i - 1], sorted_arr[i]) - - sarr.sort() - for i in range(1, len(sarr)): - self.assertLessEqual(sarr[i - 1], sarr[i]) + test_data = [ + [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], + [10, 9, 8, 7, 6, 5, 4, 3, 2, 1], + [1, 5, 10, 2, 6, 9, 7, 8, 4, 3] + ] + + for arr in test_data: + narr = np.random.randint(arr, dtype='int32') + sarr = modmesh.SimpleArrayInt32(array=narr) + args = sarr.argsort() + for i in range(1, len(args)): + self.assertLessEqual(sarr[args[i]], sarr[args[i]]) + + sorted_arr = sarr.take_along_axis(args) + for i in range(1, len(sorted_arr)): + self.assertLessEqual(sorted_arr[i - 1], sorted_arr[i]) + + sarr.sort() + for i in range(1, len(sarr)): + self.assertLessEqual(sarr[i - 1], sarr[i]) class SimpleArrayCalculatorsTC(unittest.TestCase): From 1532106bf8714854c52da3fffd689d103349adef Mon Sep 17 00:00:00 2001 From: khlee529 Date: Wed, 5 Feb 2025 19:41:19 +0800 Subject: [PATCH 15/19] Reformat code --- cpp/modmesh/buffer/SimpleArray.hpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/cpp/modmesh/buffer/SimpleArray.hpp b/cpp/modmesh/buffer/SimpleArray.hpp index 94ef5453..5848f7f6 100644 --- a/cpp/modmesh/buffer/SimpleArray.hpp +++ b/cpp/modmesh/buffer/SimpleArray.hpp @@ -804,13 +804,13 @@ A detail::SimpleArrayMixinSorters::take_along_axis(SimpleArray const & if (athis->ndim() != 1) { throw std::runtime_error(Formatter() << "SimpleArray: take_along_axis() supports currently only in 1D array " - " but the array is " << athis->ndim() << " dimension"); + " but the array is " + << athis->ndim() << " dimension"); } SimpleArray ret(indices.shape()); - std::transform(indices.begin(), indices.end(), ret.begin(), - [athis](I idx) - { return athis->at(static_cast(idx)); }); + std::transform(indices.begin(), indices.end(), ret.begin(), [athis](I idx) + { return athis->at(static_cast(idx)); }); return ret; } From 698b7f390de7ae973630c14cc0dd1509be31710c Mon Sep 17 00:00:00 2001 From: khlee529 Date: Wed, 5 Feb 2025 19:49:17 +0800 Subject: [PATCH 16/19] Fix wrong function call --- tests/test_buffer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_buffer.py b/tests/test_buffer.py index e7bf1b5c..af27a6e6 100644 --- a/tests/test_buffer.py +++ b/tests/test_buffer.py @@ -832,7 +832,7 @@ def test_sort(self): ] for arr in test_data: - narr = np.random.randint(arr, dtype='int32') + narr = np.array(arr, dtype='int32') sarr = modmesh.SimpleArrayInt32(array=narr) args = sarr.argsort() for i in range(1, len(args)): From 0821907a3bd5757087bc74058996a1e745602068 Mon Sep 17 00:00:00 2001 From: khlee529 Date: Thu, 6 Feb 2025 00:16:29 +0800 Subject: [PATCH 17/19] Refactor sort mixin --- cpp/modmesh/buffer/SimpleArray.hpp | 20 ++++++++------ cpp/modmesh/buffer/pymod/wrap_SimpleArray.cpp | 26 ++++++++++++++----- 2 files changed, 31 insertions(+), 15 deletions(-) diff --git a/cpp/modmesh/buffer/SimpleArray.hpp b/cpp/modmesh/buffer/SimpleArray.hpp index 5848f7f6..0c0a28da 100644 --- a/cpp/modmesh/buffer/SimpleArray.hpp +++ b/cpp/modmesh/buffer/SimpleArray.hpp @@ -199,7 +199,7 @@ class SimpleArrayMixinCalculators }; /* end class SimpleArrayMixinCalculators */ template -class SimpleArrayMixinSorters +class SimpleArrayMixinSort { private: @@ -215,7 +215,8 @@ class SimpleArrayMixinSorters template A take_along_axis(SimpleArray const & indices); -}; /* end class SimpleArrayMixinCalculators */ +}; /* end class SimpleArrayMixinSort */ + } /* end namespace detail */ @@ -228,7 +229,7 @@ template class SimpleArray : public detail::SimpleArrayMixinModifiers, T> , public detail::SimpleArrayMixinCalculators, T> - , public detail::SimpleArrayMixinSorters, T> + , public detail::SimpleArrayMixinSort, T> { private: @@ -754,7 +755,7 @@ class SimpleArray }; /* end class SimpleArray */ template -void detail::SimpleArrayMixinSorters::sort(void) +void detail::SimpleArrayMixinSort::sort(void) { auto athis = static_cast(this); if (athis->ndim() != 1) @@ -768,7 +769,7 @@ void detail::SimpleArrayMixinSorters::sort(void) } template -SimpleArray detail::SimpleArrayMixinSorters::argsort(void) +SimpleArray detail::SimpleArrayMixinSort::argsort(void) { auto athis = static_cast(this); if (athis->ndim() != 1) @@ -797,7 +798,7 @@ SimpleArray detail::SimpleArrayMixinSorters::argsort(void) template template -A detail::SimpleArrayMixinSorters::take_along_axis(SimpleArray const & indices) +A detail::SimpleArrayMixinSort::take_along_axis(SimpleArray const & indices) { static_assert(std::is_integral_v, "I must be integral type"); auto athis = static_cast(this); @@ -809,9 +810,12 @@ A detail::SimpleArrayMixinSorters::take_along_axis(SimpleArray const & } SimpleArray ret(indices.shape()); - std::transform(indices.begin(), indices.end(), ret.begin(), [athis](I idx) - { return athis->at(static_cast(idx)); }); + auto val_iter = ret.begin(); + for (auto idx: indices){ + *val_iter = athis->at(static_cast(idx)); + ++val_iter; + } return ret; } diff --git a/cpp/modmesh/buffer/pymod/wrap_SimpleArray.cpp b/cpp/modmesh/buffer/pymod/wrap_SimpleArray.cpp index 0fef6396..fa5ad4ad 100644 --- a/cpp/modmesh/buffer/pymod/wrap_SimpleArray.cpp +++ b/cpp/modmesh/buffer/pymod/wrap_SimpleArray.cpp @@ -147,13 +147,6 @@ class MODMESH_PYTHON_WRAPPER_VISIBILITY WrapSimpleArray "reshape", [](wrapped_type const & self, py::object const & shape) { return self.reshape(make_shape(shape)); }) - .def( - "argsort", - [](wrapped_type & self) - { return pybind11::cast(self.argsort()); }) - .def("take_along_axis", - [](wrapped_type & self, py::object const & indices) - { return pybind11::cast(self.take_along_axis(indices.cast())); }) .def_property_readonly("has_ghost", &wrapped_type::has_ghost) .def_property("nghost", &wrapped_type::nghost, &wrapped_type::set_nghost) .def_property_readonly("nbody", &wrapped_type::nbody) @@ -161,6 +154,7 @@ class MODMESH_PYTHON_WRAPPER_VISIBILITY WrapSimpleArray { return pybind11::cast(SimpleArrayPlex(arr)); }) .wrap_modifiers() .wrap_calculators() + .wrap_sort() // ATTENTION: always keep the same interface between WrapSimpleArrayPlex and WrapSimpleArray ; } @@ -192,7 +186,25 @@ class MODMESH_PYTHON_WRAPPER_VISIBILITY WrapSimpleArray .def("max", &wrapped_type::max) .def("sum", &wrapped_type::sum) .def("abs", &wrapped_type::abs) + // + ; + + return *this; + } + + wrapper_type & wrap_sort() + { + namespace py = pybind11; // NOLINT(misc-unused-alias-decls) + + (*this) .def("sort", &wrapped_type::sort) + .def( + "argsort", + [](wrapped_type & self) + { return pybind11::cast(self.argsort()); }) + .def("take_along_axis", + [](wrapped_type & self, py::object const & indices) + { return pybind11::cast(self.take_along_axis(indices.cast())); }) // ; From ad1dc262bdb21b007a752ed1058cf4e0cda75d57 Mon Sep 17 00:00:00 2001 From: khlee529 Date: Thu, 6 Feb 2025 00:47:08 +0800 Subject: [PATCH 18/19] Add test case to sort test --- tests/test_buffer.py | 21 +++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/tests/test_buffer.py b/tests/test_buffer.py index af27a6e6..86eeef84 100644 --- a/tests/test_buffer.py +++ b/tests/test_buffer.py @@ -828,12 +828,19 @@ def test_sort(self): test_data = [ [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], [10, 9, 8, 7, 6, 5, 4, 3, 2, 1], - [1, 5, 10, 2, 6, 9, 7, 8, 4, 3] + [1, 5, 10, 2, 6, 9, 7, 8, 4, 3], + [1, 0, 1, -3, -4, -1, 1, 9, 5, -4], + [-1.3, -4.8, 1.5, 0.3, 7.1, 2.5, 4.8, -0.1, 9.4, 7.6] ] - for arr in test_data: - narr = np.array(arr, dtype='int32') - sarr = modmesh.SimpleArrayInt32(array=narr) + def _check(arr, use_float=False): + if use_float: + narr = np.array(arr, dtype='float64') + sarr = modmesh.SimpleArrayFloat64(array=narr) + else: + narr = np.array(arr, dtype='int32') + sarr = modmesh.SimpleArrayInt32(array=narr) + args = sarr.argsort() for i in range(1, len(args)): self.assertLessEqual(sarr[args[i]], sarr[args[i]]) @@ -846,6 +853,12 @@ def test_sort(self): for i in range(1, len(sarr)): self.assertLessEqual(sarr[i - 1], sarr[i]) + _check(test_data[0]) + _check(test_data[1]) + _check(test_data[2]) + _check(test_data[3]) + _check(test_data[4], True) + class SimpleArrayCalculatorsTC(unittest.TestCase): From 816dadc3729087977908dff4d0698995b7d0d9cf Mon Sep 17 00:00:00 2001 From: khlee529 Date: Thu, 6 Feb 2025 00:52:57 +0800 Subject: [PATCH 19/19] Reformat code --- cpp/modmesh/buffer/SimpleArray.hpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp/modmesh/buffer/SimpleArray.hpp b/cpp/modmesh/buffer/SimpleArray.hpp index 0c0a28da..65ea6b87 100644 --- a/cpp/modmesh/buffer/SimpleArray.hpp +++ b/cpp/modmesh/buffer/SimpleArray.hpp @@ -217,7 +217,6 @@ class SimpleArrayMixinSort }; /* end class SimpleArrayMixinSort */ - } /* end namespace detail */ /** @@ -812,7 +811,8 @@ A detail::SimpleArrayMixinSort::take_along_axis(SimpleArray const & ind SimpleArray ret(indices.shape()); auto val_iter = ret.begin(); - for (auto idx: indices){ + for (auto idx : indices) + { *val_iter = athis->at(static_cast(idx)); ++val_iter; }