-
Notifications
You must be signed in to change notification settings - Fork 46
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implement sort
and argsort
for one dimensional SimpleArray
in C++
#456
base: master
Are you sure you want to change the base?
Changes from 10 commits
bb2f6bd
a381336
94d78d6
00bc2b6
3ac729a
6a7f163
6dd619d
1f59f92
1998843
45b2d4d
0e803b5
111175e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -34,6 +34,7 @@ | |
#include <stdexcept> | ||
#include <functional> | ||
#include <numeric> | ||
#include <algorithm> | ||
|
||
#if defined(_MSC_VER) | ||
#include <BaseTsd.h> | ||
|
@@ -192,6 +193,18 @@ class SimpleArrayMixinCalculators | |
} | ||
return ret; | ||
} | ||
|
||
void sort(void) | ||
{ | ||
auto athis = static_cast<A *>(this); | ||
if (athis->ndim() != 1) | ||
{ | ||
throw std::runtime_error(Formatter() << "SimpleArray: sorting is only supported in 1D array. " | ||
<< athis->ndim() << "D array is currently not supported"); | ||
} | ||
|
||
std::sort(athis->begin(), athis->end()); | ||
} | ||
}; /* end class SimpleArrayMixinCalculators */ | ||
|
||
} /* end namespace detail */ | ||
|
@@ -585,6 +598,33 @@ class SimpleArray | |
} | ||
} | ||
|
||
SimpleArray<uint64_t> argsort(void) | ||
{ | ||
if (ndim() != 1) | ||
{ | ||
throw std::runtime_error(Formatter() << "SimpleArray: sorting is only supported in 1D array. " | ||
<< ndim() << "D array is currently not supported"); | ||
} | ||
|
||
SimpleArray<uint64_t> ret(shape()); | ||
|
||
{ // Return array initialization | ||
uint64_t cnt = 0; | ||
std::for_each(ret.begin(), ret.end(), [&cnt](uint64_t & v) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Initialize the array to be returned to the ascending indices |
||
{ v = cnt++; }); | ||
} | ||
|
||
value_type const * buf = body(); | ||
auto cmp = [buf](uint64_t a, uint64_t b) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. A comparison function that take the input as indices of the SimpleArray and compare the value of the SimpleArray at the given index. |
||
{ | ||
return buf[a] < buf[b]; | ||
}; | ||
std::sort(ret.begin(), ret.end(), cmp); | ||
return ret; | ||
} | ||
|
||
void take_along_axis(SimpleArray<uint64_t> const & indices); | ||
|
||
template <typename... Args> | ||
value_type const & operator()(Args... args) const { return *vptr(args...); } | ||
template <typename... Args> | ||
|
@@ -729,6 +769,64 @@ class SimpleArray | |
value_type * m_body = nullptr; | ||
}; /* end class SimpleArray */ | ||
|
||
template <typename T> | ||
void SimpleArray<T>::take_along_axis(SimpleArray<uint64_t> const & indices) | ||
{ | ||
if (indices.ndim() != 1) | ||
{ | ||
throw std::runtime_error(Formatter() << "SimpleArray: sorting is only supported in 1D array. " | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please reword to be like |
||
<< indices.ndim() << "D indices is not supported."); | ||
} | ||
if (ndim() != 1) | ||
{ | ||
throw std::runtime_error(Formatter() << "SimpleArray: sorting is only supported in 1D array. " | ||
<< ndim() << "D array is not supported to be sorted."); | ||
} | ||
if (shape()[0] != indices.shape()[0]) | ||
{ | ||
throw std::runtime_error(Formatter() << "SimpleArray: take_along_axis only support same shape of indices and array." | ||
<< "Array size " << shape()[0] << " != indices shape " << indices.shape()[0]); | ||
} | ||
if (shape()[0] < 2) | ||
{ | ||
return; | ||
} | ||
|
||
std::vector<bool> applied_arg(shape()[0], false); | ||
|
||
auto next = [](std::vector<bool> & vec, ssize_t last) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do not use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think although the time complexity of the By the way, here is the analysis of the gpt-o3-mini. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It is prefered to use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Agree the overall complexity should still be @KHLee529 , please avoid the lambda. It will make the code clearer and better optimized. The use of lambda also makes it difficult if not impossible for SIMD. Please also avoid bound-checking as much as possible. Is there a way to reduce the loop and/or avoid the nesting? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
I notice that the space complexity of this function is O(n). If we directly create the new buffer |
||
{ | ||
for (ssize_t i = last; i < static_cast<ssize_t>(vec.size()); ++i) | ||
{ | ||
if (vec.at(i) == false) | ||
{ | ||
return i; | ||
} | ||
} | ||
return static_cast<ssize_t>(-1); | ||
}; | ||
|
||
ssize_t idx = 0; | ||
while ((idx = next(applied_arg, idx)) != -1) | ||
{ | ||
value_type val = at(idx); | ||
|
||
ssize_t dst_idx = idx; | ||
ssize_t src_idx = indices[dst_idx]; | ||
|
||
while (src_idx != idx) | ||
{ | ||
at(dst_idx) = at(src_idx); | ||
applied_arg.at(dst_idx) = true; | ||
dst_idx = src_idx; | ||
src_idx = indices[dst_idx]; | ||
} | ||
|
||
at(dst_idx) = val; | ||
applied_arg.at(dst_idx) = true; | ||
} | ||
} | ||
|
||
template <typename S> | ||
using is_simple_array = std::is_same< | ||
std::remove_reference_t<S>, | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -824,6 +824,17 @@ def test_SimpleArray_SimpleArrayPlex_type_switch(self): | |
self.assertEqual( | ||
str(type(arrayplex_int32_2)), "<class '_modmesh.SimpleArray'>") | ||
|
||
def test_sort(self): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Test |
||
narr = np.random.randint(0, 100, 20, dtype='int32') | ||
sarr = modmesh.SimpleArrayInt32(array=narr) | ||
args = sarr.argsort() | ||
for i in range(1, len(args)): | ||
self.assertLessEqual(sarr[args[i]], sarr[args[i]]) | ||
|
||
sarr.sort() | ||
for i in range(1, len(args)): | ||
self.assertLessEqual(sarr[i - 1], sarr[i]) | ||
|
||
|
||
class SimpleArrayCalculatorsTC(unittest.TestCase): | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The function that returns a SimpleArray containing the indices which will sort the array
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
argsort()
is of the same group assort()
and should also go toSimpleArrayMixinCalculators
.In PR we do not add comments for every detail. Please be comprehensive when addressing review comments.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
And the function is too long and should be moved outside the class declaration.