diff --git a/source/binary/Add.hpp b/source/binary/Add.hpp index a08fd70..7761f80 100644 --- a/source/binary/Add.hpp +++ b/source/binary/Add.hpp @@ -15,89 +15,260 @@ namespace Langulus::SIMD { /// Used to detect missing SIMD routine - LANGULUS(INLINED) + template LANGULUS(INLINED) constexpr Unsupported AddSIMD(CT::NotSIMD auto, CT::NotSIMD auto) noexcept { return {}; } /// Add two registers + /// @tparam SATURATE - whether to clamp to [min;max] /// @param lhs - left register /// @param rhs - right register /// @return the resulting register - template LANGULUS(INLINED) - R AddSIMD(R lhs, R rhs) noexcept { + template LANGULUS(INLINED) + auto AddSIMD(R lhs, R rhs) noexcept { using T = TypeOf; (void)lhs; (void)rhs; - if constexpr (CT::SIMD128) { - if constexpr (CT::SignedInteger8) return simde_mm_add_epi8 (lhs, rhs); - else if constexpr (CT::UnsignedInteger8) return simde_mm_adds_epu8 (lhs, rhs); - else if constexpr (CT::SignedInteger16) return simde_mm_add_epi16 (lhs, rhs); - else if constexpr (CT::UnsignedInteger16) return simde_mm_adds_epu16 (lhs, rhs); - else if constexpr (CT::Integer32) return simde_mm_add_epi32 (lhs, rhs); - else if constexpr (CT::Integer64) return simde_mm_add_epi64 (lhs, rhs); - else if constexpr (CT::Float) return simde_mm_add_ps (lhs, rhs); - else if constexpr (CT::Double) return simde_mm_add_pd (lhs, rhs); - else static_assert(false, "Unsupported type for 16-byte package"); + if constexpr (SATURATE) { + if constexpr (CT::SIMD128) { + if constexpr (CT::SignedInteger8) return R {simde_mm_adds_epi8 (lhs, rhs)}; + else if constexpr (CT::UnsignedInteger8) return R {simde_mm_adds_epu8 (lhs, rhs)}; + else if constexpr (CT::SignedInteger16) return R {simde_mm_adds_epi16 (lhs, rhs)}; + else if constexpr (CT::UnsignedInteger16) return R {simde_mm_adds_epu16 (lhs, rhs)}; + else if constexpr (CT::SignedInteger32) { + // https://stackoverflow.com/questions/29498824 + const auto int_max = simde_mm_set1_epi32(::std::numeric_limits::max()); + const auto res = simde_mm_add_epi32 (lhs, rhs); + const auto sign_bit = simde_mm_srli_epi32(lhs, 31); + #if LANGULUS_SIMD(AVX512VL) + const auto overflow = simde_mm_ternarylogic_epi32(lhs, rhs, res, 0x42); + #else + const auto sign_xor = simde_mm_xor_si128(lhs, rhs); + const auto overflow = simde_mm_andnot_si128(sign_xor, simde_mm_xor_si128(lhs, res)); + #endif + + #if LANGULUS_SIMD(AVX512DQ) and LANGULUS_SIMD(AVX512VL) + return R {simde_mm_mask_add_epi32(res, simde_mm_movepi32_mask(overflow), int_max, sign_bit)}; + #else + const auto saturated = simde_mm_add_epi32(int_max, sign_bit); + return R {simde_mm_castps_si128(simde_mm_blendv_ps( + simde_mm_castsi128_ps(res), + simde_mm_castsi128_ps(saturated), + simde_mm_castsi128_ps(overflow) + ))}; + #endif + } + else if constexpr (CT::UnsignedInteger32) { + const auto mx = simde_mm_min_epu32(lhs, not rhs); + return R {simde_mm_add_epi32(mx, rhs)}; + } + else if constexpr (CT::SignedInteger64) + return R {simde_mm_add_epi64(lhs, rhs)}; + else if constexpr (CT::UnsignedInteger64) { + #if LANGULUS_SIMD(AVX512F) and LANGULUS_SIMD(AVX512VL) + const auto mx = simde_mm_min_epu64(lhs, not rhs); + return R {simde_mm_add_epi64(mx, rhs)}; + #else + return Unsupported {}; + #endif + } + else if constexpr (CT::Float) { + // Clamp to [0;1] range + return R {simde_mm_max_ps(simde_mm_min_ps( + simde_mm_add_ps(lhs, rhs), simde_mm_set1_ps(1)), simde_mm_set1_ps(0))}; + } + else if constexpr (CT::Double) { + // Clamp to [0;1] range + return R {simde_mm_max_pd(simde_mm_min_pd( + simde_mm_add_pd(lhs, rhs), simde_mm_set1_pd(1)), simde_mm_set1_pd(0))}; + } + else static_assert(false, "Unsupported type for 16-byte package"); + } + else if constexpr (CT::SIMD256) { + if constexpr (CT::SignedInteger8) return R {simde_mm256_adds_epi8 (lhs, rhs)}; + else if constexpr (CT::UnsignedInteger8) return R {simde_mm256_adds_epu8 (lhs, rhs)}; + else if constexpr (CT::SignedInteger16) return R {simde_mm256_adds_epi16 (lhs, rhs)}; + else if constexpr (CT::UnsignedInteger16) return R {simde_mm256_adds_epu16 (lhs, rhs)}; + else if constexpr (CT::SignedInteger32) { + // https://stackoverflow.com/questions/29498824 + const auto int_max = simde_mm256_set1_epi32(::std::numeric_limits::max()); + const auto res = simde_mm256_add_epi32 (lhs, rhs); + const auto sign_bit = simde_mm256_srli_epi32(lhs, 31); + #if LANGULUS_SIMD(AVX512VL) + const auto overflow = simde_mm256_ternarylogic_epi32(lhs, rhs, res, 0x42); + #else + const auto sign_xor = simde_mm256_xor_si256(lhs, rhs); + const auto overflow = simde_mm256_andnot_si256(sign_xor, simde_mm256_xor_si256(lhs, res)); + #endif + + #if LANGULUS_SIMD(AVX512DQ) and LANGULUS_SIMD(AVX512VL) + return R {simde_mm256_mask_add_epi32(res, simde_mm256_movepi32_mask(overflow), int_max, sign_bit)}; + #else + const auto saturated = simde_mm256_add_epi32(int_max, sign_bit); + return R {simde_mm256_castps_si256(simde_mm256_blendv_ps( + simde_mm256_castsi256_ps(res), + simde_mm256_castsi256_ps(saturated), + simde_mm256_castsi256_ps(overflow) + ))}; + #endif + } + else if constexpr (CT::UnsignedInteger32) { + const auto mx = simde_mm256_min_epu32(lhs, not rhs); + return R {simde_mm256_add_epi32(mx, rhs)}; + } + else if constexpr (CT::SignedInteger64) + return R {simde_mm256_add_epi64(lhs, rhs)}; + else if constexpr (CT::UnsignedInteger64) { + #if LANGULUS_SIMD(AVX512F) and LANGULUS_SIMD(AVX512VL) + const auto mx = simde_mm256_min_epu64(lhs, not rhs); + return R {simde_mm256_add_epi64(mx, rhs)}; + #else + return Unsupported {}; + #endif + } + else if constexpr (CT::Float) { + // Clamp to [0;1] range + return R {simde_mm256_max_ps(simde_mm256_min_ps( + simde_mm256_add_ps(lhs, rhs), simde_mm256_set1_ps(1)), simde_mm256_set1_ps(0))}; + } + else if constexpr (CT::Double) { + // Clamp to [0;1] range + return R {simde_mm256_max_pd(simde_mm256_min_pd( + simde_mm256_add_pd(lhs, rhs), simde_mm256_set1_pd(1)), simde_mm256_set1_pd(0))}; + } + else static_assert(false, "Unsupported type for 32-byte package"); + } + else if constexpr (CT::SIMD512) { + if constexpr (CT::SignedInteger8) return R {simde_mm512_adds_epi8 (lhs, rhs)}; + else if constexpr (CT::UnsignedInteger8) return R {simde_mm512_adds_epu8 (lhs, rhs)}; + else if constexpr (CT::SignedInteger16) return R {simde_mm512_adds_epi16 (lhs, rhs)}; + else if constexpr (CT::UnsignedInteger16) return R {simde_mm512_adds_epu16 (lhs, rhs)}; + else if constexpr (CT::SignedInteger32){ + // https://stackoverflow.com/questions/29498824 + const auto int_max = simde_mm512_set1_epi32(::std::numeric_limits::max()); + const auto res = simde_mm512_add_epi32 (lhs, rhs); + const auto sign_bit = simde_mm512_srli_epi32(lhs, 31); + const auto overflow = simde_mm512_ternarylogic_epi32(lhs, rhs, res, 0x42); + return R {simde_mm512_mask_add_epi32(res, simde_mm512_movepi32_mask(overflow), int_max, sign_bit)}; + } + else if constexpr (CT::UnsignedInteger32) { + const auto mx = simde_mm512_min_epu32(lhs, not rhs); + return R {simde_mm512_add_epi32(mx, rhs)}; + } + else if constexpr (CT::SignedInteger64) + return R {simde_mm512_add_epi64(lhs, rhs)}; + else if constexpr (CT::UnsignedInteger64) { + const auto mx = simde_mm512_min_epu64(lhs, not rhs); + return R {simde_mm512_add_epi64(mx, rhs)}; + } + else if constexpr (CT::Float) { + // Clamp to [0;1] range + return R {simde_mm512_max_ps(simde_mm512_min_ps( + simde_mm512_add_ps(lhs, rhs), simde_mm512_set1_ps(1)), simde_mm512_set1_ps(0))}; + } + else if constexpr (CT::Double) { + // Clamp to [0;1] range + return R {simde_mm512_max_pd(simde_mm512_min_pd( + simde_mm512_add_pd(lhs, rhs), simde_mm512_set1_pd(1)), simde_mm512_set1_pd(0))}; + } + else static_assert(false, "Unsupported type for 64-byte package"); + } + else static_assert(false, "Unsupported type"); } - else if constexpr (CT::SIMD256) { - if constexpr (CT::SignedInteger8) return simde_mm256_add_epi8 (lhs, rhs); - else if constexpr (CT::UnsignedInteger8) return simde_mm256_adds_epu8 (lhs, rhs); - else if constexpr (CT::SignedInteger16) return simde_mm256_add_epi16 (lhs, rhs); - else if constexpr (CT::UnsignedInteger16) return simde_mm256_adds_epu16 (lhs, rhs); - else if constexpr (CT::Integer32) return simde_mm256_add_epi32 (lhs, rhs); - else if constexpr (CT::Integer64) return simde_mm256_add_epi64 (lhs, rhs); - else if constexpr (CT::Float) return simde_mm256_add_ps (lhs, rhs); - else if constexpr (CT::Double) return simde_mm256_add_pd (lhs, rhs); - else static_assert(false, "Unsupported type for 32-byte package"); + else { + if constexpr (CT::SIMD128) { + if constexpr (CT::Integer8) return R {simde_mm_add_epi8 (lhs, rhs)}; + else if constexpr (CT::Integer16) return R {simde_mm_add_epi16 (lhs, rhs)}; + else if constexpr (CT::Integer32) return R {simde_mm_add_epi32 (lhs, rhs)}; + else if constexpr (CT::Integer64) return R {simde_mm_add_epi64 (lhs, rhs)}; + else if constexpr (CT::Float) return R {simde_mm_add_ps (lhs, rhs)}; + else if constexpr (CT::Double) return R {simde_mm_add_pd (lhs, rhs)}; + else static_assert(false, "Unsupported type for 16-byte package"); + } + else if constexpr (CT::SIMD256) { + if constexpr (CT::Integer8) return R {simde_mm256_add_epi8 (lhs, rhs)}; + else if constexpr (CT::Integer16) return R {simde_mm256_add_epi16 (lhs, rhs)}; + else if constexpr (CT::Integer32) return R {simde_mm256_add_epi32 (lhs, rhs)}; + else if constexpr (CT::Integer64) return R {simde_mm256_add_epi64 (lhs, rhs)}; + else if constexpr (CT::Float) return R {simde_mm256_add_ps (lhs, rhs)}; + else if constexpr (CT::Double) return R {simde_mm256_add_pd (lhs, rhs)}; + else static_assert(false, "Unsupported type for 32-byte package"); + } + else if constexpr (CT::SIMD512) { + if constexpr (CT::Integer8) return R {simde_mm512_add_epi8 (lhs, rhs)}; + else if constexpr (CT::Integer16) return R {simde_mm512_add_epi16 (lhs, rhs)}; + else if constexpr (CT::Integer32) return R {simde_mm512_add_epi32 (lhs, rhs)}; + else if constexpr (CT::Integer64) return R {simde_mm512_add_epi64 (lhs, rhs)}; + else if constexpr (CT::Float) return R {simde_mm512_add_ps (lhs, rhs)}; + else if constexpr (CT::Double) return R {simde_mm512_add_pd (lhs, rhs)}; + else static_assert(false, "Unsupported type for 64-byte package"); + } + else static_assert(false, "Unsupported type"); } - else if constexpr (CT::SIMD512) { - if constexpr (CT::SignedInteger8) return simde_mm512_add_epi8 (lhs, rhs); - else if constexpr (CT::UnsignedInteger8) return simde_mm512_adds_epu8 (lhs, rhs); - else if constexpr (CT::SignedInteger16) return simde_mm512_add_epi16 (lhs, rhs); - else if constexpr (CT::UnsignedInteger16) return simde_mm512_adds_epu16 (lhs, rhs); - else if constexpr (CT::Integer32) return simde_mm512_add_epi32 (lhs, rhs); - else if constexpr (CT::Integer64) return simde_mm512_add_epi64 (lhs, rhs); - else if constexpr (CT::Float) return simde_mm512_add_ps (lhs, rhs); - else if constexpr (CT::Double) return simde_mm512_add_pd (lhs, rhs); - else static_assert(false, "Unsupported type for 64-byte package"); + } + + /// Fallback addition + /// @tparam SATURATE - whether to clamp to max if overflow occurs + template LANGULUS(INLINED) + constexpr E AddFallback(const E& lhs, const E& rhs) noexcept { + if constexpr (SATURATE) { + using WIDER = WiderSigned; + + if constexpr (sizeof(WIDER) == sizeof(E) and CT::Integer) { + // If WIDER type isn't wider, perform the saturation + // by hand + constexpr E lo = ::std::numeric_limits::min(); + constexpr E hi = ::std::numeric_limits::max(); + + if (rhs > 0) + return lhs > hi - rhs ? hi : lhs + rhs; + else + return lhs < lo - rhs ? lo : lhs + rhs; + } + else if constexpr (CT::Integer) + return Saturate(static_cast(lhs) + static_cast(rhs)); + else + return Saturate(lhs + rhs); } - else static_assert(false, "Unsupported type"); + else return lhs + rhs; } - + /// Get sum of values as constexpr, if possible + /// @tparam SATURATE - whether to clamp to max if overflow occurs /// @tparam FORCE_OUT - the desired element type (lossless if void) /// @patam value - scalar/vector to operate on /// @return the summed scalar/vector - template LANGULUS(INLINED) + template LANGULUS(INLINED) constexpr auto AddConstexpr(const auto& lhs, const auto& rhs) noexcept { return AttemptBinary<0, FORCE_OUT>(lhs, rhs, nullptr, [](const E& l, const E& r) noexcept -> E { - return l + r; + return AddFallback(l, r); } ); } /// Get summed values as a register, if possible + /// @tparam SATURATE - whether to clamp to max if overflow occurs /// @tparam FORCE_OUT - the desired element type (lossless if void) /// @patam value - scalar/vector/register to operate on /// @return the summed scalar/vector/register - template LANGULUS(INLINED) + template LANGULUS(INLINED) auto Add(const auto& lhs, const auto& rhs) noexcept { return AttemptBinary<0, FORCE_OUT>(lhs, rhs, [](const R& l, const R& r) noexcept { LANGULUS_SIMD_VERBOSE("Adding (SIMD) as ", NameOf()); - return AddSIMD(l, r); + return AddSIMD(l, r); }, [](const E& l, const E& r) noexcept -> E { LANGULUS_SIMD_VERBOSE("Adding (Fallback) ", l, " + ", r, " (", NameOf(), ")"); - return l + r; + return AddFallback(l, r); } ); } } // namespace Langulus::SIMD::Inner - LANGULUS_SIMD_ARITHMETHIC_API(Add) + LANGULUS_SIMD_ARITHMETHIC_WITH_SATURATION_API(Add) } // namespace Langulus::SIMD diff --git a/source/binary/Subtract.hpp b/source/binary/Subtract.hpp index 2548c0b..70236d7 100644 --- a/source/binary/Subtract.hpp +++ b/source/binary/Subtract.hpp @@ -21,7 +21,7 @@ namespace Langulus::SIMD } /// Subtract two registers - /// @tparam SATURATE - whether to clamp to max if overflow occurs + /// @tparam SATURATE - whether to clamp to [min;max] /// @param lhs - left register /// @param rhs - right register /// @return the resulting register diff --git a/test/Add/TestAdd-VS.cpp b/test/Add/TestAdd-VS.cpp new file mode 100644 index 0000000..a9cd8a0 --- /dev/null +++ b/test/Add/TestAdd-VS.cpp @@ -0,0 +1,166 @@ +/// +/// Langulus::SIMD +/// Copyright (c) 2019 Dimo Markov +/// Part of the Langulus framework, see https://langulus.com +/// +/// SPDX-License-Identifier: MIT +/// +#include "TestAdd.hpp" + + +TEMPLATE_TEST_CASE("Vector + Scalar", "[add]" + , NUMBERS_ALL() + , VECTORS_ALL(1) + , VECTORS_ALL(2) + , VECTORS_ALL(3) + , VECTORS_ALL(4) + , VECTORS_ALL(5) + , VECTORS_ALL(8) + , VECTORS_ALL(9) + , VECTORS_ALL(16) + , VECTORS_ALL(17) + , VECTORS_ALL(32) + , VECTORS_ALL(33) +) { + using T = TestType; + using E = TypeOf; + static_assert(CountOf> == 2); + + GIVEN("Vector + Scalar = Vector") { + T x; + E y {}; + T r, rCheck; + + if constexpr (not CT::Vector) { + InitOne(x, 1); + InitOne(y, -5); + } + else InitOne(y, -5); + + WHEN("Added as constexpr (with saturation)") { + constexpr T lhs = E {0}; + constexpr E rhs = E {5}; + static_assert(SIMD::Add(lhs, rhs) == T {CT::Real> ? 1 : 5}); + } + + WHEN("Added as constexpr (without saturation)") { + constexpr T lhs = E {0}; + constexpr E rhs = E {5}; + static_assert(SIMD::Add(lhs, rhs) == static_cast(5)); + } + + WHEN("Added (with saturation)") { + ControlAdd(x, y, rCheck); + SIMD::Add(x, y, r); + + REQUIRE(r == rCheck); + + #ifdef LANGULUS_STD_BENCHMARK + BENCHMARK_ADVANCED("Add (control)") (timer meter) { + some nx(meter.runs()); + if constexpr (not CT::Vector) { + for (auto& i : nx) + InitOne(i, 1); + } + + some ny(meter.runs()); + if constexpr (not CT::Vector) { + for (auto& i : ny) + InitOne(i, 1); + } + + some nr(meter.runs()); + meter.measure([&](int i) { + ControlAdd(nx[i], ny[i], nr[i]); + }); + }; + + BENCHMARK_ADVANCED("Add (SIMD)") (timer meter) { + some nx(meter.runs()); + if constexpr (not CT::Vector) { + for (auto& i : nx) + InitOne(i, 1); + } + + some ny(meter.runs()); + if constexpr (not CT::Vector) { + for (auto& i : ny) + InitOne(i, 1); + } + + some nr(meter.runs()); + meter.measure([&](int i) { + if constexpr (CT::Vector) + SIMD::Add(nx[i].mArray, ny[i].mArray, nr[i].mArray); + else + SIMD::Add(nx[i], ny[i], nr[i]); + }); + }; + #endif + } + + WHEN("Added (without saturation)") { + ControlAdd(x, y, rCheck); + SIMD::Add(x, y, r); + + REQUIRE(r == rCheck); + + #ifdef LANGULUS_STD_BENCHMARK + BENCHMARK_ADVANCED("Add (control)") (timer meter) { + some nx(meter.runs()); + if constexpr (not CT::Vector) { + for (auto& i : nx) + InitOne(i, 1); + } + + some ny(meter.runs()); + if constexpr (not CT::Vector) { + for (auto& i : ny) + InitOne(i, 1); + } + + some nr(meter.runs()); + meter.measure([&](int i) { + ControlAdd(nx[i], ny[i], nr[i]); + }); + }; + + BENCHMARK_ADVANCED("Add (SIMD)") (timer meter) { + some nx(meter.runs()); + if constexpr (not CT::Vector) { + for (auto& i : nx) + InitOne(i, 1); + } + + some ny(meter.runs()); + if constexpr (not CT::Vector) { + for (auto& i : ny) + InitOne(i, 1); + } + + some nr(meter.runs()); + meter.measure([&](int i) { + if constexpr (CT::Vector) + SIMD::Add(nx[i].mArray, ny[i].mArray, nr[i].mArray); + else + SIMD::Add(nx[i], ny[i], nr[i]); + }); + }; + #endif + } + + WHEN("Added in reverse (with saturation)") { + ControlAdd(y, x, rCheck); + SIMD::Add(y, x, r); + + REQUIRE(r == rCheck); + } + + WHEN("Added in reverse (without saturation)") { + ControlAdd(y, x, rCheck); + SIMD::Add(y, x, r); + + REQUIRE(r == rCheck); + } + } +} \ No newline at end of file diff --git a/test/Add/TestAdd-VV.cpp b/test/Add/TestAdd-VV.cpp new file mode 100644 index 0000000..4a7e5be --- /dev/null +++ b/test/Add/TestAdd-VV.cpp @@ -0,0 +1,162 @@ +/// +/// Langulus::SIMD +/// Copyright (c) 2019 Dimo Markov +/// Part of the Langulus framework, see https://langulus.com +/// +/// SPDX-License-Identifier: MIT +/// +#include "TestAdd.hpp" + + +TEMPLATE_TEST_CASE("Vector + Vector", "[add]" + , NUMBERS_ALL() + , VECTORS_ALL(1) + , VECTORS_ALL(2) + , VECTORS_ALL(3) + , VECTORS_ALL(4) + , VECTORS_ALL(5) + , VECTORS_ALL(8) + , VECTORS_ALL(9) + , VECTORS_ALL(16) + , VECTORS_ALL(17) + , VECTORS_ALL(32) + , VECTORS_ALL(33) +) { + using T = TestType; + + GIVEN("x * y = r") { + T x, y; + T r, rCheck; + + if constexpr (not CT::Vector) { + InitOne(x, 1); + InitOne(y, -5); + } + + WHEN("Added as constexpr (with saturation)") { + constexpr T lhs {0}; + constexpr T rhs {5}; + static_assert(SIMD::Add(lhs, rhs) == T {CT::Real> ? 1 : 5}); + } + + WHEN("Added as constexpr (without saturation)") { + constexpr T lhs {0}; + constexpr T rhs {5}; + static_assert(SIMD::Add(lhs, rhs) == static_cast(5)); + } + + WHEN("Added (with saturation)") { + ControlAdd(x, y, rCheck); + SIMD::Add(x, y, r); + + REQUIRE(r == rCheck); + + #ifdef LANGULUS_STD_BENCHMARK + BENCHMARK_ADVANCED("Add (control)") (timer meter) { + some nx(meter.runs()); + if constexpr (not CT::Vector) { + for (auto& i : nx) + InitOne(i, 1); + } + + some ny(meter.runs()); + if constexpr (not CT::Vector) { + for (auto& i : ny) + InitOne(i, 1); + } + + some nr(meter.runs()); + meter.measure([&](int i) { + ControlAdd(nx[i], ny[i], nr[i]); + }); + }; + + BENCHMARK_ADVANCED("Add (SIMD)") (timer meter) { + some nx(meter.runs()); + if constexpr (not CT::Vector) { + for (auto& i : nx) + InitOne(i, 1); + } + + some ny(meter.runs()); + if constexpr (not CT::Vector) { + for (auto& i : ny) + InitOne(i, 1); + } + + some nr(meter.runs()); + meter.measure([&](int i) { + if constexpr (CT::Vector) + SIMD::Add(nx[i].mArray, ny[i].mArray, nr[i].mArray); + else + SIMD::Add(nx[i], ny[i], nr[i]); + }); + }; + #endif + } + + WHEN("Added (without saturation)") { + ControlAdd(x, y, rCheck); + SIMD::Add(x, y, r); + + REQUIRE(r == rCheck); + + #ifdef LANGULUS_STD_BENCHMARK + BENCHMARK_ADVANCED("Add (control)") (timer meter) { + some nx(meter.runs()); + if constexpr (not CT::Vector) { + for (auto& i : nx) + InitOne(i, 1); + } + + some ny(meter.runs()); + if constexpr (not CT::Vector) { + for (auto& i : ny) + InitOne(i, 1); + } + + some nr(meter.runs()); + meter.measure([&](int i) { + ControlAdd(nx[i], ny[i], nr[i]); + }); + }; + + BENCHMARK_ADVANCED("Add (SIMD)") (timer meter) { + some nx(meter.runs()); + if constexpr (not CT::Vector) { + for (auto& i : nx) + InitOne(i, 1); + } + + some ny(meter.runs()); + if constexpr (not CT::Vector) { + for (auto& i : ny) + InitOne(i, 1); + } + + some nr(meter.runs()); + meter.measure([&](int i) { + if constexpr (CT::Vector) + SIMD::Add(nx[i].mArray, ny[i].mArray, nr[i].mArray); + else + SIMD::Add(nx[i], ny[i], nr[i]); + }); + }; + #endif + } + + WHEN("Added in reverse (with saturation)") { + ControlAdd(y, x, rCheck); + SIMD::Add(y, x, r); + + REQUIRE(r == rCheck); + } + + WHEN("Added in reverse (without saturation)") { + ControlAdd(y, x, rCheck); + SIMD::Add(y, x, r); + + REQUIRE(r == rCheck); + } + } +} diff --git a/test/Add/TestAdd.hpp b/test/Add/TestAdd.hpp new file mode 100644 index 0000000..fe9dfe7 --- /dev/null +++ b/test/Add/TestAdd.hpp @@ -0,0 +1,51 @@ +/// +/// Langulus::SIMD +/// Copyright (c) 2019 Dimo Markov +/// Part of the Langulus framework, see https://langulus.com +/// +/// SPDX-License-Identifier: MIT +/// +#pragma once +#include "../Common.hpp" + + +/// Scalar + Scalar (either dense or sparse, wrapped or not) +template LANGULUS(INLINED) +void ControlAdd(const LHS& lhs, const RHS& rhs, OUT& out) noexcept { + auto& fout = FundamentalCast(out); + fout = SIMD::Inner::AddFallback(FundamentalCast(lhs), FundamentalCast(rhs)); +} + +/// Vector + Vector (either dense or sparse, wrapped or not) +template LANGULUS(INLINED) +void ControlAdd(const LHS& lhsArray, const RHS& rhsArray, OUT& out) noexcept { + static_assert(LHS::MemberCount == RHS::MemberCount + and LHS::MemberCount == OUT::MemberCount, + "Vector sizes must match"); + + auto r = out.mArray; + auto lhs = lhsArray.mArray; + auto rhs = rhsArray.mArray; + const auto lhsEnd = lhs + LHS::MemberCount; + while (lhs != lhsEnd) + ControlAdd(*lhs++, *rhs++, *r++); +} + +/// Scalar + Vector (either dense or sparse, wrapped or not) +template LANGULUS(INLINED) +void ControlAdd(const LHS& lhs, const RHS& rhsArray, OUT& out) noexcept { + static_assert(RHS::MemberCount == OUT::MemberCount, + "Vector sizes must match"); + + auto r = out.mArray; + auto rhs = rhsArray.mArray; + const auto rhsEnd = rhs + RHS::MemberCount; + while (rhs != rhsEnd) + ControlAdd(lhs, *rhs++, *r++); +} + +/// Vector + Scalar (either dense or sparse, wrapped or not) +template LANGULUS(INLINED) +void ControlAdd(const LHS& lhsArray, const RHS& rhs, OUT& out) noexcept { + return ControlAdd(rhs, lhsArray, out); +} \ No newline at end of file diff --git a/test/Multiply/TestMul.hpp b/test/Multiply/TestMul.hpp index 0fdf7eb..db735da 100644 --- a/test/Multiply/TestMul.hpp +++ b/test/Multiply/TestMul.hpp @@ -10,7 +10,6 @@ /// Scalar * Scalar (either dense or sparse, wrapped or not) -/// @attention 8bit integers are always multiplied with saturation template LANGULUS(INLINED) void ControlMul(const LHS& lhs, const RHS& rhs, OUT& out) noexcept { auto& fout = FundamentalCast(out); diff --git a/test/Subtract/TestSub.hpp b/test/Subtract/TestSub.hpp index 9e92f4c..d3a60de 100644 --- a/test/Subtract/TestSub.hpp +++ b/test/Subtract/TestSub.hpp @@ -10,7 +10,6 @@ /// Scalar - Scalar (either dense or sparse, wrapped or not) -/// @attention 8bit integers are always multiplied with saturation template LANGULUS(INLINED) void ControlSub(const LHS& lhs, const RHS& rhs, OUT& out) noexcept { auto& fout = FundamentalCast(out); diff --git a/test/TestAdd.cpp b/test/TestAdd.cpp deleted file mode 100644 index 23d20ef..0000000 --- a/test/TestAdd.cpp +++ /dev/null @@ -1,105 +0,0 @@ -/// -/// Langulus::SIMD -/// Copyright (c) 2019 Dimo Markov -/// Part of the Langulus framework, see https://langulus.com -/// -/// SPDX-License-Identifier: MIT -/// -#include "Common.hpp" - - -template LANGULUS(INLINED) -void ControlAdd(const LHS& lhs, const RHS& rhs, OUT& out) noexcept { - out = lhs + rhs; -} - -template LANGULUS(INLINED) -void ControlAdd(const Vector& lhsArray, const Vector& rhsArray, Vector& out) noexcept { - auto r = out.mArray; - auto lhs = lhsArray.mArray; - auto rhs = rhsArray.mArray; - const auto lhsEnd = lhs + C; - while (lhs != lhsEnd) - ControlAdd(*lhs++, *rhs++, *r++); -} - -TEMPLATE_TEST_CASE("Add", "[add]" - , NUMBERS_ALL() - , VECTORS_ALL(1) - , VECTORS_ALL(2) - , VECTORS_ALL(3) - , VECTORS_ALL(4) - , VECTORS_ALL(5) - , VECTORS_ALL(8) - , VECTORS_ALL(9) - , VECTORS_ALL(16) - , VECTORS_ALL(17) - , VECTORS_ALL(32) - , VECTORS_ALL(33) -) { - using T = TestType; - - GIVEN("x + y = r") { - T x, y; - T r, rCheck; - - if constexpr (not CT::Vector) { - InitOne(x, 1); - InitOne(y, -5); - } - - WHEN("Added") { - ControlAdd(x, y, rCheck); - SIMD::Add(x, y, r); - - REQUIRE(r == rCheck); - - #ifdef LANGULUS_STD_BENCHMARK - BENCHMARK_ADVANCED("Add (control)") (timer meter) { - some nx(meter.runs()); - if constexpr (not CT::Vector) { - for (auto& i : nx) - InitOne(i, 1); - } - - some ny(meter.runs()); - if constexpr (not CT::Vector) { - for (auto& i : ny) - InitOne(i, 1); - } - - some nr(meter.runs()); - meter.measure([&](int i) { - ControlAdd(nx[i], ny[i], nr[i]); - }); - }; - - BENCHMARK_ADVANCED("Add (SIMD)") (timer meter) { - some nx(meter.runs()); - if constexpr (not CT::Vector) { - for (auto& i : nx) - InitOne(i, 1); - } - - some ny(meter.runs()); - if constexpr (not CT::Vector) { - for (auto& i : ny) - InitOne(i, 1); - } - - some nr(meter.runs()); - meter.measure([&](int i) { - SIMD::Add(nx[i], ny[i], nr[i]); - }); - }; - #endif - } - - WHEN("Added in reverse") { - ControlAdd(y, x, rCheck); - SIMD::Add(y, x, r); - - REQUIRE(r == rCheck); - } - } -} \ No newline at end of file