Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 55 additions & 29 deletions include/xsimd/arch/xsimd_neon.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -717,16 +717,10 @@ namespace xsimd
return vnegq_s32(rhs);
}

template <class A, class T, detail::enable_sized_unsigned_t<T, 8> = 0>
XSIMD_INLINE batch<T, A> neg(batch<T, A> const& rhs, requires_arch<neon>) noexcept
{
return batch<T, A> { -rhs.get(0), -rhs.get(1) };
}

template <class A, class T, detail::enable_sized_signed_t<T, 8> = 0>
template <class A, class T, detail::enable_sized_integral_t<T, 8> = 0>
XSIMD_INLINE batch<T, A> neg(batch<T, A> const& rhs, requires_arch<neon>) noexcept
{
return batch<T, A> { -rhs.get(0), -rhs.get(1) };
return 0 - rhs;
}

template <class A>
Expand Down Expand Up @@ -923,16 +917,28 @@ namespace xsimd
return dispatcher.apply(register_type(lhs), register_type(rhs));
}

template <class A, class T, detail::enable_sized_integral_t<T, 8> = 0>
template <class A, class T, detail::enable_sized_unsigned_t<T, 8> = 0>
XSIMD_INLINE batch_bool<T, A> eq(batch<T, A> const& lhs, batch<T, A> const& rhs, requires_arch<neon>) noexcept
{
auto eq32 = vceqq_u32(vreinterpretq_u32_u64(lhs.data), vreinterpretq_u32_u64(rhs.data));
auto rev32 = vrev64q_u32(eq32);
auto eq64 = vandq_u32(eq32, rev32);
return batch_bool<T, A>(vreinterpretq_u64_u32(eq64));
}

template <class A, class T, detail::enable_sized_signed_t<T, 8> = 0>
XSIMD_INLINE batch_bool<T, A> eq(batch<T, A> const& lhs, batch<T, A> const& rhs, requires_arch<neon>) noexcept
{
return batch_bool<T, A>({ lhs.get(0) == rhs.get(0), lhs.get(1) == rhs.get(1) });
auto eq32 = vceqq_u32(vreinterpretq_u32_s64(lhs.data), vreinterpretq_u32_s64(rhs.data));
auto rev32 = vrev64q_u32(eq32);
auto eq64 = vandq_u32(eq32, rev32);
return batch_bool<T, A>(vreinterpretq_u64_u32(eq64));
}

template <class A, class T, detail::enable_sized_integral_t<T, 8> = 0>
XSIMD_INLINE batch_bool<T, A> eq(batch_bool<T, A> const& lhs, batch_bool<T, A> const& rhs, requires_arch<neon>) noexcept
{
return batch_bool<T, A>({ lhs.get(0) == rhs.get(0), lhs.get(1) == rhs.get(1) });
return eq(batch<T, A> { lhs.data }, batch<T, A> { rhs.data }, A {});
}

/*************
Expand Down Expand Up @@ -985,10 +991,19 @@ namespace xsimd
return dispatcher.apply(register_type(lhs), register_type(rhs));
}

template <class A, class T, detail::enable_sized_integral_t<T, 8> = 0>
template <class A, class T, detail::enable_sized_signed_t<T, 8> = 0>
XSIMD_INLINE batch_bool<T, A> lt(batch<T, A> const& lhs, batch<T, A> const& rhs, requires_arch<neon>) noexcept
{
return batch_bool<T, A>({ lhs.get(0) < rhs.get(0), lhs.get(1) < rhs.get(1) });
using register_type = typename batch<T, A>::register_type;
return batch_bool<T, A>(vreinterpretq_u64_s64(vshrq_n_s64(vqsubq_s64(register_type(lhs), register_type(rhs)), 63)));
}

template <class A, class T, detail::enable_sized_unsigned_t<T, 8> = 0>
XSIMD_INLINE batch_bool<T, A> lt(batch<T, A> const& lhs, batch<T, A> const& rhs, requires_arch<neon>) noexcept
{
using register_type = typename batch<T, A>::register_type;
register_type acc = { 0x7FFFFFFFFFFFFFFFull, 0x7FFFFFFFFFFFFFFFull };
return batch_bool<T, A>(vreinterpretq_u64_s64(vshrq_n_s64(vreinterpretq_s64_u64(vqaddq_u64(vqsubq_u64(register_type(rhs), register_type(lhs)), acc)), 63)));
}

/******
Expand All @@ -1012,12 +1027,24 @@ namespace xsimd
template <class A, class T, detail::enable_sized_integral_t<T, 8> = 0>
XSIMD_INLINE batch_bool<T, A> le(batch<T, A> const& lhs, batch<T, A> const& rhs, requires_arch<neon>) noexcept
{
return batch_bool<T, A>({ lhs.get(0) <= rhs.get(0), lhs.get(1) <= rhs.get(1) });
return !(lhs > rhs);
}

/******
* gt *
******/
namespace detail
{
XSIMD_INLINE int64x2_t bitwise_not_s64(int64x2_t arg) noexcept
{
return vreinterpretq_s64_s32(vmvnq_s32(vreinterpretq_s32_s64(arg)));
}

XSIMD_INLINE uint64x2_t bitwise_not_u64(uint64x2_t arg) noexcept
{
return vreinterpretq_u64_u32(vmvnq_u32(vreinterpretq_u32_u64(arg)));
}
}

WRAP_BINARY_INT_EXCLUDING_64(vcgtq, detail::comp_return_type)
WRAP_BINARY_FLOAT(vcgtq, detail::comp_return_type)
Expand All @@ -1033,10 +1060,19 @@ namespace xsimd
return dispatcher.apply(register_type(lhs), register_type(rhs));
}

template <class A, class T, detail::enable_sized_integral_t<T, 8> = 0>
template <class A, class T, detail::enable_sized_signed_t<T, 8> = 0>
XSIMD_INLINE batch_bool<T, A> gt(batch<T, A> const& lhs, batch<T, A> const& rhs, requires_arch<neon>) noexcept
{
return batch_bool<T, A>({ lhs.get(0) > rhs.get(0), lhs.get(1) > rhs.get(1) });
using register_type = typename batch<T, A>::register_type;
return batch_bool<T, A>(vreinterpretq_u64_s64(vshrq_n_s64(vqsubq_s64(register_type(rhs), register_type(lhs)), 63)));
}

template <class A, class T, detail::enable_sized_unsigned_t<T, 8> = 0>
XSIMD_INLINE batch_bool<T, A> gt(batch<T, A> const& lhs, batch<T, A> const& rhs, requires_arch<neon>) noexcept
{
using register_type = typename batch<T, A>::register_type;
register_type acc = { 0x7FFFFFFFFFFFFFFFull, 0x7FFFFFFFFFFFFFFFull };
return batch_bool<T, A>(vreinterpretq_u64_s64(vshrq_n_s64(vreinterpretq_s64_u64(vqaddq_u64(vqsubq_u64(register_type(lhs), register_type(rhs)), acc)), 63)));
}

/******
Expand All @@ -1060,7 +1096,7 @@ namespace xsimd
template <class A, class T, detail::enable_sized_integral_t<T, 8> = 0>
XSIMD_INLINE batch_bool<T, A> ge(batch<T, A> const& lhs, batch<T, A> const& rhs, requires_arch<neon>) noexcept
{
return batch_bool<T, A>({ lhs.get(0) >= rhs.get(0), lhs.get(1) >= rhs.get(1) });
return !(lhs < rhs);
}

/*******************
Expand Down Expand Up @@ -1212,16 +1248,6 @@ namespace xsimd

namespace detail
{
XSIMD_INLINE int64x2_t bitwise_not_s64(int64x2_t arg) noexcept
{
return vreinterpretq_s64_s32(vmvnq_s32(vreinterpretq_s32_s64(arg)));
}

XSIMD_INLINE uint64x2_t bitwise_not_u64(uint64x2_t arg) noexcept
{
return vreinterpretq_u64_u32(vmvnq_u32(vreinterpretq_u32_u64(arg)));
}

XSIMD_INLINE float32x4_t bitwise_not_f32(float32x4_t arg) noexcept
{
return vreinterpretq_f32_u32(vmvnq_u32(vreinterpretq_u32_f32(arg)));
Expand Down Expand Up @@ -1314,7 +1340,7 @@ namespace xsimd
template <class A, class T, detail::enable_sized_integral_t<T, 8> = 0>
XSIMD_INLINE batch<T, A> min(batch<T, A> const& lhs, batch<T, A> const& rhs, requires_arch<neon>) noexcept
{
return { std::min(lhs.get(0), rhs.get(0)), std::min(lhs.get(1), rhs.get(1)) };
return select(lhs > rhs, rhs, lhs);
}

/*******
Expand All @@ -1338,7 +1364,7 @@ namespace xsimd
template <class A, class T, detail::enable_sized_integral_t<T, 8> = 0>
XSIMD_INLINE batch<T, A> max(batch<T, A> const& lhs, batch<T, A> const& rhs, requires_arch<neon>) noexcept
{
return { std::max(lhs.get(0), rhs.get(0)), std::max(lhs.get(1), rhs.get(1)) };
return select(lhs > rhs, lhs, rhs);
}

/*******
Expand Down
20 changes: 20 additions & 0 deletions test/test_batch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -477,6 +477,11 @@ struct batch_test
auto res = batch_lhs() < batch_rhs();
INFO("batch < batch");
CHECK_BATCH_EQ(res, expected);

std::fill(expected.begin(), expected.end(), false);
res = batch_lhs() < batch_lhs();
INFO("batch < (self)");
CHECK_BATCH_EQ(res, expected);
}
// batch < scalar
{
Expand All @@ -502,6 +507,11 @@ struct batch_test
auto res = batch_lhs() <= batch_rhs();
INFO("batch <= batch");
CHECK_BATCH_EQ(res, expected);

std::fill(expected.begin(), expected.end(), true);
res = batch_lhs() <= batch_lhs();
INFO("batch < (self)");
CHECK_BATCH_EQ(res, expected);
}
// batch <= scalar
{
Expand All @@ -527,6 +537,11 @@ struct batch_test
auto res = batch_lhs() > batch_rhs();
INFO("batch > batch");
CHECK_BATCH_EQ(res, expected);

std::fill(expected.begin(), expected.end(), false);
res = batch_lhs() > batch_lhs();
INFO("batch > (self)");
CHECK_BATCH_EQ(res, expected);
}
// batch > scalar
{
Expand All @@ -551,6 +566,11 @@ struct batch_test
auto res = batch_lhs() >= batch_rhs();
INFO("batch >= batch");
CHECK_BATCH_EQ(res, expected);

std::fill(expected.begin(), expected.end(), true);
res = batch_lhs() >= batch_lhs();
INFO("batch >= (self)");
CHECK_BATCH_EQ(res, expected);
}
// batch >= scalar
{
Expand Down