diff --git a/hpx-sys/include/wrapper.h b/hpx-sys/include/wrapper.h index 5aad2ec..42e2af1 100644 --- a/hpx-sys/include/wrapper.h +++ b/hpx-sys/include/wrapper.h @@ -180,3 +180,35 @@ inline void hpx_merge(const rust::Vec& src1, dest.push_back(item); } } + +inline void hpx_partial_sort(rust::Vec& src, size_t last) { + std::vector cpp_vec(src.begin(), src.end()); + + hpx::partial_sort(hpx::execution::par, + cpp_vec.begin(), + cpp_vec.begin() + last, + cpp_vec.end()); + + src.clear(); + src.reserve(cpp_vec.size()); + for (const auto& item : cpp_vec) { + src.push_back(item); + } +} + +inline void hpx_partial_sort_comp(rust::Vec& src, size_t last, + rust::Fn comp) { + std::vector cpp_vec(src.begin(), src.end()); + + hpx::partial_sort(hpx::execution::par, + cpp_vec.begin(), + cpp_vec.begin() + last, + cpp_vec.end(), + [&](int32_t a, int32_t b) { return comp(a, b); }); + + src.clear(); + src.reserve(cpp_vec.size()); + for (const auto& item : cpp_vec) { + src.push_back(item); + } +} diff --git a/hpx-sys/src/lib.rs b/hpx-sys/src/lib.rs index badcf7a..c6f4aaa 100644 --- a/hpx-sys/src/lib.rs +++ b/hpx-sys/src/lib.rs @@ -30,6 +30,8 @@ pub mod ffi { fn hpx_sort(src: &mut Vec); fn hpx_sort_comp(src: &mut Vec, comp: fn(i32, i32) -> bool); fn hpx_merge(src1: &Vec, src2: &Vec, dest: &mut Vec); + fn hpx_partial_sort(src: &mut Vec, last: usize); + fn hpx_partial_sort_comp(src: &mut Vec, last: usize, comp: fn(i32, i32) -> bool); } } @@ -421,4 +423,64 @@ mod tests { assert_eq!(result, 0); } } + + #[test] + #[serial] + fn test_hpx_partial_sort() { + let (argc, mut argv) = create_c_args(&["test_hpx_partial_sort"]); + + let hpx_main = |_argc: i32, _argv: *mut *mut c_char| -> i32 { + let mut vec = vec![5, 2, 8, 1, 9, 3, 7, 6, 4]; + let last = 4; + println!("Before partial sort: {:?}", vec); + + ffi::hpx_partial_sort(&mut vec, last); + println!("After partial sort: {:?}", vec); + + // If first -> last elements are sorted + assert!(vec[..last].windows(2).all(|w| w[0] <= w[1])); + + // If ele of sorted part <= ele of unsorted part + assert!(vec[..last] + .iter() + .all(|&x| vec[last..].iter().all(|&y| x <= y))); + + ffi::finalize() + }; + + unsafe { + let result = ffi::init(hpx_main, argc, argv.as_mut_ptr()); + assert_eq!(result, 0); + } + } + + #[test] + #[serial] + fn test_hpx_partial_sort_comp() { + let (argc, mut argv) = create_c_args(&["test_hpx_partial_sort_comp"]); + + let hpx_main = |_argc: i32, _argv: *mut *mut c_char| -> i32 { + let mut vec = vec![5, 2, 8, 1, 9, 3, 7, 6, 4]; + let last = 4; + println!("Before partial sort: {:?}", vec); + + ffi::hpx_partial_sort_comp(&mut vec, last, |a, b| b < a); + println!("After partial sort: {:?}", vec); + + // If first -> last elements are sorted dec + assert!(vec[..last].windows(2).all(|w| w[0] >= w[1])); + + // If ele of sorted part >= ele of unsorted part + assert!(vec[..last] + .iter() + .all(|&x| vec[last..].iter().all(|&y| x >= y))); + + ffi::finalize() + }; + + unsafe { + let result = ffi::init(hpx_main, argc, argv.as_mut_ptr()); + assert_eq!(result, 0); + } + } }