diff --git a/include/xsimd/arch/xsimd_neon.hpp b/include/xsimd/arch/xsimd_neon.hpp index dfbc235ff..e8031e6e6 100644 --- a/include/xsimd/arch/xsimd_neon.hpp +++ b/include/xsimd/arch/xsimd_neon.hpp @@ -717,16 +717,10 @@ namespace xsimd return vnegq_s32(rhs); } - template = 0> - XSIMD_INLINE batch neg(batch const& rhs, requires_arch) noexcept - { - return batch { -rhs.get(0), -rhs.get(1) }; - } - - template = 0> + template = 0> XSIMD_INLINE batch neg(batch const& rhs, requires_arch) noexcept { - return batch { -rhs.get(0), -rhs.get(1) }; + return 0 - rhs; } template @@ -923,16 +917,28 @@ namespace xsimd return dispatcher.apply(register_type(lhs), register_type(rhs)); } - template = 0> + template = 0> + XSIMD_INLINE batch_bool eq(batch const& lhs, batch const& rhs, requires_arch) noexcept + { + auto eq32 = vceqq_u32(vreinterpretq_u32_u64(lhs.data), vreinterpretq_u32_u64(rhs.data)); + auto rev32 = vrev64q_u32(eq32); + auto eq64 = vandq_u32(eq32, rev32); + return batch_bool(vreinterpretq_u64_u32(eq64)); + } + + template = 0> XSIMD_INLINE batch_bool eq(batch const& lhs, batch const& rhs, requires_arch) noexcept { - return batch_bool({ lhs.get(0) == rhs.get(0), lhs.get(1) == rhs.get(1) }); + auto eq32 = vceqq_u32(vreinterpretq_u32_s64(lhs.data), vreinterpretq_u32_s64(rhs.data)); + auto rev32 = vrev64q_u32(eq32); + auto eq64 = vandq_u32(eq32, rev32); + return batch_bool(vreinterpretq_u64_u32(eq64)); } template = 0> XSIMD_INLINE batch_bool eq(batch_bool const& lhs, batch_bool const& rhs, requires_arch) noexcept { - return batch_bool({ lhs.get(0) == rhs.get(0), lhs.get(1) == rhs.get(1) }); + return eq(batch { lhs.data }, batch { rhs.data }, A {}); } /************* @@ -985,10 +991,19 @@ namespace xsimd return dispatcher.apply(register_type(lhs), register_type(rhs)); } - template = 0> + template = 0> XSIMD_INLINE batch_bool lt(batch const& lhs, batch const& rhs, requires_arch) noexcept { - return batch_bool({ lhs.get(0) < rhs.get(0), lhs.get(1) < rhs.get(1) }); + using 
register_type = typename batch::register_type; + return batch_bool(vreinterpretq_u64_s64(vshrq_n_s64(vqsubq_s64(register_type(lhs), register_type(rhs)), 63))); + } + + template = 0> + XSIMD_INLINE batch_bool lt(batch const& lhs, batch const& rhs, requires_arch) noexcept + { + using register_type = typename batch::register_type; + register_type acc = { 0x7FFFFFFFFFFFFFFFull, 0x7FFFFFFFFFFFFFFFull }; + return batch_bool(vreinterpretq_u64_s64(vshrq_n_s64(vreinterpretq_s64_u64(vqaddq_u64(vqsubq_u64(register_type(rhs), register_type(lhs)), acc)), 63))); } /****** @@ -1012,12 +1027,24 @@ namespace xsimd template = 0> XSIMD_INLINE batch_bool le(batch const& lhs, batch const& rhs, requires_arch) noexcept { - return batch_bool({ lhs.get(0) <= rhs.get(0), lhs.get(1) <= rhs.get(1) }); + return !(lhs > rhs); } /****** * gt * ******/ + namespace detail + { + XSIMD_INLINE int64x2_t bitwise_not_s64(int64x2_t arg) noexcept + { + return vreinterpretq_s64_s32(vmvnq_s32(vreinterpretq_s32_s64(arg))); + } + + XSIMD_INLINE uint64x2_t bitwise_not_u64(uint64x2_t arg) noexcept + { + return vreinterpretq_u64_u32(vmvnq_u32(vreinterpretq_u32_u64(arg))); + } + } WRAP_BINARY_INT_EXCLUDING_64(vcgtq, detail::comp_return_type) WRAP_BINARY_FLOAT(vcgtq, detail::comp_return_type) @@ -1033,10 +1060,19 @@ namespace xsimd return dispatcher.apply(register_type(lhs), register_type(rhs)); } - template = 0> + template = 0> XSIMD_INLINE batch_bool gt(batch const& lhs, batch const& rhs, requires_arch) noexcept { - return batch_bool({ lhs.get(0) > rhs.get(0), lhs.get(1) > rhs.get(1) }); + using register_type = typename batch::register_type; + return batch_bool(vreinterpretq_u64_s64(vshrq_n_s64(vqsubq_s64(register_type(rhs), register_type(lhs)), 63))); + } + + template = 0> + XSIMD_INLINE batch_bool gt(batch const& lhs, batch const& rhs, requires_arch) noexcept + { + using register_type = typename batch::register_type; + register_type acc = { 0x7FFFFFFFFFFFFFFFull, 0x7FFFFFFFFFFFFFFFull }; + return 
batch_bool(vreinterpretq_u64_s64(vshrq_n_s64(vreinterpretq_s64_u64(vqaddq_u64(vqsubq_u64(register_type(lhs), register_type(rhs)), acc)), 63))); } /****** @@ -1060,7 +1096,7 @@ namespace xsimd template = 0> XSIMD_INLINE batch_bool ge(batch const& lhs, batch const& rhs, requires_arch) noexcept { - return batch_bool({ lhs.get(0) >= rhs.get(0), lhs.get(1) >= rhs.get(1) }); + return !(lhs < rhs); } /******************* @@ -1212,16 +1248,6 @@ namespace xsimd namespace detail { - XSIMD_INLINE int64x2_t bitwise_not_s64(int64x2_t arg) noexcept - { - return vreinterpretq_s64_s32(vmvnq_s32(vreinterpretq_s32_s64(arg))); - } - - XSIMD_INLINE uint64x2_t bitwise_not_u64(uint64x2_t arg) noexcept - { - return vreinterpretq_u64_u32(vmvnq_u32(vreinterpretq_u32_u64(arg))); - } - XSIMD_INLINE float32x4_t bitwise_not_f32(float32x4_t arg) noexcept { return vreinterpretq_f32_u32(vmvnq_u32(vreinterpretq_u32_f32(arg))); @@ -1314,7 +1340,7 @@ namespace xsimd template = 0> XSIMD_INLINE batch min(batch const& lhs, batch const& rhs, requires_arch) noexcept { - return { std::min(lhs.get(0), rhs.get(0)), std::min(lhs.get(1), rhs.get(1)) }; + return select(lhs > rhs, rhs, lhs); } /******* @@ -1338,7 +1364,7 @@ namespace xsimd template = 0> XSIMD_INLINE batch max(batch const& lhs, batch const& rhs, requires_arch) noexcept { - return { std::max(lhs.get(0), rhs.get(0)), std::max(lhs.get(1), rhs.get(1)) }; + return select(lhs > rhs, lhs, rhs); } /******* diff --git a/test/test_batch.cpp b/test/test_batch.cpp index ff043dc8b..d23b53729 100644 --- a/test/test_batch.cpp +++ b/test/test_batch.cpp @@ -477,6 +477,11 @@ struct batch_test auto res = batch_lhs() < batch_rhs(); INFO("batch < batch"); CHECK_BATCH_EQ(res, expected); + + std::fill(expected.begin(), expected.end(), false); + res = batch_lhs() < batch_lhs(); + INFO("batch < (self)"); + CHECK_BATCH_EQ(res, expected); } // batch < scalar { @@ -502,6 +507,11 @@ struct batch_test auto res = batch_lhs() <= batch_rhs(); INFO("batch <= batch"); 
CHECK_BATCH_EQ(res, expected); + + std::fill(expected.begin(), expected.end(), true); + res = batch_lhs() <= batch_lhs(); + INFO("batch <= (self)"); + CHECK_BATCH_EQ(res, expected); } // batch <= scalar { @@ -527,6 +537,11 @@ struct batch_test auto res = batch_lhs() > batch_rhs(); INFO("batch > batch"); CHECK_BATCH_EQ(res, expected); + + std::fill(expected.begin(), expected.end(), false); + res = batch_lhs() > batch_lhs(); + INFO("batch > (self)"); + CHECK_BATCH_EQ(res, expected); } // batch > scalar { @@ -551,6 +566,11 @@ struct batch_test auto res = batch_lhs() >= batch_rhs(); INFO("batch >= batch"); CHECK_BATCH_EQ(res, expected); + + std::fill(expected.begin(), expected.end(), true); + res = batch_lhs() >= batch_lhs(); + INFO("batch >= (self)"); + CHECK_BATCH_EQ(res, expected); } // batch >= scalar {