Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ jobs:
build:
runs-on: ${{ matrix.os }}
container: ${{ matrix.container && matrix.container || '' }}
name: ${{ matrix.name }}${{ matrix.arch && format('-{0}', matrix.arch) || '' }} build${{ matrix.arch != 'arm64-v8a' && matrix.name != 'ios-sim' && matrix.name != 'ios' && matrix.name != 'apple-xcframework' && matrix.name != 'android-aar' && ( matrix.name != 'macos' || matrix.arch != 'x86_64' ) && ' + test' || ''}}
name: ${{ matrix.name }}${{ matrix.arch && format('-{0}', matrix.arch) || '' }} build${{ matrix.arch != 'arm64-v8a' && matrix.arch != 'armeabi-v7a' && matrix.name != 'ios-sim' && matrix.name != 'ios' && matrix.name != 'apple-xcframework' && matrix.name != 'android-aar' && ( matrix.name != 'macos' || matrix.arch != 'x86_64' ) && ' + test' || ''}}
timeout-minutes: 20
strategy:
fail-fast: false
Expand Down Expand Up @@ -47,6 +47,10 @@ jobs:
arch: arm64-v8a
name: android
make: PLATFORM=android ARCH=arm64-v8a
- os: ubuntu-22.04
arch: armeabi-v7a
name: android
make: PLATFORM=android ARCH=armeabi-v7a
- os: ubuntu-22.04
arch: x86_64
name: android
Expand Down Expand Up @@ -140,7 +144,7 @@ jobs:
security delete-keychain build.keychain

- name: android setup test environment
if: matrix.name == 'android' && matrix.arch != 'arm64-v8a'
if: matrix.name == 'android' && matrix.arch != 'arm64-v8a' && matrix.arch != 'armeabi-v7a'
run: |

echo "::group::enable kvm group perms"
Expand Down Expand Up @@ -168,7 +172,7 @@ jobs:
echo "::endgroup::"

- name: android test sqlite-vector
if: matrix.name == 'android' && matrix.arch != 'arm64-v8a'
if: matrix.name == 'android' && matrix.arch != 'arm64-v8a' && matrix.arch != 'armeabi-v7a'
uses: reactivecircus/android-emulator-runner@v2.34.0
with:
api-level: 26
Expand Down
19 changes: 14 additions & 5 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -59,16 +59,22 @@ else ifeq ($(PLATFORM),macos)
STRIP = strip -x -S $@
else ifeq ($(PLATFORM),android)
ifndef ARCH # Set ARCH to find Android NDK's Clang compiler, the user should set the ARCH
$(error "Android ARCH must be set to ARCH=x86_64 or ARCH=arm64-v8a")
$(error "Android ARCH must be set to ARCH=x86_64, ARCH=arm64-v8a, or ARCH=armeabi-v7a")
endif
ifndef ANDROID_NDK # Set ANDROID_NDK path to find android build tools; e.g. on MacOS: export ANDROID_NDK=/Users/username/Library/Android/sdk/ndk/25.2.9519653
$(error "Android NDK must be set")
endif
BIN = $(ANDROID_NDK)/toolchains/llvm/prebuilt/$(HOST)-x86_64/bin
ifneq (,$(filter $(ARCH),arm64 arm64-v8a))
override ARCH := aarch64
ANDROID_ABI := android26
else ifeq ($(ARCH),armeabi-v7a)
override ARCH := armv7a
ANDROID_ABI := androideabi26
else
ANDROID_ABI := android26
endif
CC = $(BIN)/$(ARCH)-linux-android26-clang
CC = $(BIN)/$(ARCH)-linux-$(ANDROID_ABI)-clang
TARGET := $(DIST_DIR)/vector.so
LDFLAGS += -lm -shared
STRIP = $(BIN)/llvm-strip --strip-unneeded $@
Expand Down Expand Up @@ -184,11 +190,14 @@ $(DIST_DIR)/%.xcframework: $(LIB_NAMES)

xcframework: $(DIST_DIR)/vector.xcframework

AAR_ARM = packages/android/src/main/jniLibs/arm64-v8a/
AAR_ARM64 = packages/android/src/main/jniLibs/arm64-v8a/
AAR_ARM = packages/android/src/main/jniLibs/armeabi-v7a/
AAR_X86 = packages/android/src/main/jniLibs/x86_64/
aar:
mkdir -p $(AAR_ARM) $(AAR_X86)
mkdir -p $(AAR_ARM64) $(AAR_ARM) $(AAR_X86)
$(MAKE) clean && $(MAKE) PLATFORM=android ARCH=arm64-v8a
mv $(DIST_DIR)/vector.so $(AAR_ARM64)
$(MAKE) clean && $(MAKE) PLATFORM=android ARCH=armeabi-v7a
mv $(DIST_DIR)/vector.so $(AAR_ARM)
$(MAKE) clean && $(MAKE) PLATFORM=android ARCH=x86_64
mv $(DIST_DIR)/vector.so $(AAR_X86)
Expand All @@ -208,7 +217,7 @@ help:
@echo " linux (default on Linux)"
@echo " macos (default on macOS)"
@echo " windows (default on Windows)"
@echo " android (needs ARCH to be set to x86_64 or arm64-v8a and ANDROID_NDK to be set)"
@echo " android (needs ARCH to be set to x86_64, arm64-v8a, or armeabi-v7a and ANDROID_NDK to be set)"
@echo " ios (only on macOS)"
@echo " ios-sim (only on macOS)"
@echo ""
Expand Down
128 changes: 128 additions & 0 deletions src/distance-neon.c
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,27 @@


#if defined(__ARM_NEON) || defined(__ARM_NEON__)

// Flag 32-bit ARM builds: inside this NEON-only section a 4-byte pointer size
// marks the target as 32-bit. Code guarded by _ARM32BIT_ replaces intrinsics
// this file treats as unavailable there (vmaxv_u16, float64x2_t arithmetic)
// with compatible fallbacks.
#if __SIZEOF_POINTER__ == 4
#define _ARM32BIT_ 1
#endif

#include <arm_neon.h>

extern distance_function_t dispatch_distance_table[VECTOR_DISTANCE_MAX][VECTOR_TYPE_MAX];
extern char *distance_backend_name;

// 32-bit ARM fallback for vmaxv_u16 (horizontal max across a uint16x4_t),
// which ARMv7 NEON does not provide. The rest of the file keeps calling
// vmaxv_u16; on 32-bit builds the macro below routes those calls here.
#ifdef _ARM32BIT_
static inline uint16_t vmaxv_u16_compat(uint16x4_t v) {
    // First pairwise pass: lanes become [max(v0,v1), max(v2,v3), max(v0,v1), max(v2,v3)].
    const uint16x4_t pair_max = vpmax_u16(v, v);
    // Second pairwise pass folds the two partial maxima, so every lane holds
    // the maximum of all four inputs.
    const uint16x4_t lane_max = vpmax_u16(pair_max, pair_max);
    // Any lane would do; read lane 0.
    return vget_lane_u16(lane_max, 0);
}
#define vmaxv_u16 vmaxv_u16_compat
#endif

// MARK: FLOAT32 -

float float32_distance_l2_impl_neon (const void *v1, const void *v2, int n, bool use_sqrt) {
Expand Down Expand Up @@ -158,6 +174,31 @@ float bfloat16_distance_l2_impl_neon (const void *v1, const void *v2, int n, boo
const uint16_t *a = (const uint16_t *)v1;
const uint16_t *b = (const uint16_t *)v2;

#ifdef _ARM32BIT_
// 32-bit ARM: use scalar double accumulation (float64x2_t is not available in 32-bit ARM NEON)
double sum = 0.0;
int i = 0;

for (; i <= n - 4; i += 4) {
uint16x4_t av16 = vld1_u16(a + i);
uint16x4_t bv16 = vld1_u16(b + i);

float32x4_t va = bf16x4_to_f32x4_u16(av16);
float32x4_t vb = bf16x4_to_f32x4_u16(bv16);
float32x4_t d = vsubq_f32(va, vb);
// mask-out NaNs: m = (d==d)
uint32x4_t m = vceqq_f32(d, d);
d = vbslq_f32(m, d, vdupq_n_f32(0.0f));

// Store and accumulate in scalar double
float tmp[4];
vst1q_f32(tmp, d);
for (int j = 0; j < 4; j++) {
double dj = (double)tmp[j];
sum = fma(dj, dj, sum);
}
}
#else
// Accumulate in f64 to avoid overflow from huge bf16 values.
float64x2_t acc0 = vdupq_n_f64(0.0), acc1 = vdupq_n_f64(0.0);
int i = 0;
Expand Down Expand Up @@ -205,6 +246,7 @@ float bfloat16_distance_l2_impl_neon (const void *v1, const void *v2, int n, boo
}

double sum = vaddvq_f64(vaddq_f64(acc0, acc1));
#endif

// scalar tail; treat NaN as 0, Inf as +Inf result
for (; i < n; ++i) {
Expand Down Expand Up @@ -409,8 +451,15 @@ float float16_distance_l2_impl_neon (const void *v1, const void *v2, int n, bool
const uint16x4_t SIGN_MASK = vdup_n_u16(0x8000u);
const uint16x4_t ZERO16 = vdup_n_u16(0);

#ifdef _ARM32BIT_
// 32-bit ARM: use scalar double accumulation
double sum = 0.0;
int i = 0;
#else
// 64-bit ARM: use float64x2_t NEON intrinsics
float64x2_t acc0 = vdupq_n_f64(0.0), acc1 = vdupq_n_f64(0.0);
int i = 0;
#endif

for (; i <= n - 4; i += 4) {
uint16x4_t av16 = vld1_u16(a + i);
Expand Down Expand Up @@ -443,6 +492,16 @@ float float16_distance_l2_impl_neon (const void *v1, const void *v2, int n, bool
uint32x4_t m = vceqq_f32(d32, d32); /* true where not-NaN */
d32 = vbslq_f32(m, d32, vdupq_n_f32(0.0f));

#ifdef _ARM32BIT_
// 32-bit ARM: accumulate in scalar double
float tmp[4];
vst1q_f32(tmp, d32);
for (int j = 0; j < 4; j++) {
double dj = (double)tmp[j];
sum = fma(dj, dj, sum);
}
#else
// 64-bit ARM: use NEON f64 operations
float64x2_t dlo = vcvt_f64_f32(vget_low_f32(d32));
float64x2_t dhi = vcvt_f64_f32(vget_high_f32(d32));
#if defined(__ARM_FEATURE_FMA)
Expand All @@ -451,10 +510,13 @@ float float16_distance_l2_impl_neon (const void *v1, const void *v2, int n, bool
#else
acc0 = vaddq_f64(acc0, vmulq_f64(dlo, dlo));
acc1 = vaddq_f64(acc1, vmulq_f64(dhi, dhi));
#endif
#endif
}

#ifndef _ARM32BIT_
double sum = vaddvq_f64(vaddq_f64(acc0, acc1));
#endif

/* tail (scalar; same Inf/NaN policy) */
for (; i < n; ++i) {
Expand Down Expand Up @@ -487,10 +549,17 @@ float float16_distance_cosine_neon (const void *v1, const void *v2, int n) {
const uint16x4_t FRAC_MASK = vdup_n_u16(0x03FFu);
const uint16x4_t ZERO16 = vdup_n_u16(0);

#ifdef _ARM32BIT_
// 32-bit ARM: use scalar double accumulation
double dot = 0.0, normx = 0.0, normy = 0.0;
int i = 0;
#else
// 64-bit ARM: use float64x2_t NEON intrinsics
float64x2_t acc_dot_lo = vdupq_n_f64(0.0), acc_dot_hi = vdupq_n_f64(0.0);
float64x2_t acc_a2_lo = vdupq_n_f64(0.0), acc_a2_hi = vdupq_n_f64(0.0);
float64x2_t acc_b2_lo = vdupq_n_f64(0.0), acc_b2_hi = vdupq_n_f64(0.0);
int i = 0;
#endif

for (; i <= n - 4; i += 4) {
uint16x4_t av16 = vld1_u16(a + i);
Expand All @@ -512,6 +581,19 @@ float float16_distance_cosine_neon (const void *v1, const void *v2, int n) {
ax = vbslq_f32(mx, ax, vdupq_n_f32(0.0f));
by = vbslq_f32(my, by, vdupq_n_f32(0.0f));

#ifdef _ARM32BIT_
// 32-bit ARM: accumulate in scalar double
float ax_tmp[4], by_tmp[4];
vst1q_f32(ax_tmp, ax);
vst1q_f32(by_tmp, by);
for (int j = 0; j < 4; j++) {
double x = (double)ax_tmp[j];
double y = (double)by_tmp[j];
dot += x * y;
normx += x * x;
normy += y * y;
}
#else
/* widen to f64 and accumulate */
float64x2_t ax_lo = vcvt_f64_f32(vget_low_f32(ax)), ax_hi = vcvt_f64_f32(vget_high_f32(ax));
float64x2_t by_lo = vcvt_f64_f32(vget_low_f32(by)), by_hi = vcvt_f64_f32(vget_high_f32(by));
Expand All @@ -530,12 +612,15 @@ float float16_distance_cosine_neon (const void *v1, const void *v2, int n) {
acc_a2_hi = vaddq_f64(acc_a2_hi, vmulq_f64(ax_hi, ax_hi));
acc_b2_lo = vaddq_f64(acc_b2_lo, vmulq_f64(by_lo, by_lo));
acc_b2_hi = vaddq_f64(acc_b2_hi, vmulq_f64(by_hi, by_hi));
#endif
#endif
}

#ifndef _ARM32BIT_
double dot = vaddvq_f64(vaddq_f64(acc_dot_lo, acc_dot_hi));
double normx= vaddvq_f64(vaddq_f64(acc_a2_lo, acc_a2_hi));
double normy= vaddvq_f64(vaddq_f64(acc_b2_lo, acc_b2_hi));
#endif

/* tail (scalar) */
for (; i < n; ++i) {
Expand Down Expand Up @@ -569,8 +654,15 @@ float float16_distance_dot_neon (const void *v1, const void *v2, int n) {
const uint16x4_t FRAC_MASK = vdup_n_u16(0x03FFu);
const uint16x4_t ZERO16 = vdup_n_u16(0);

#ifdef _ARM32BIT_
// 32-bit ARM: use scalar double accumulation
double dot = 0.0;
int i = 0;
#else
// 64-bit ARM: use float64x2_t NEON intrinsics
float64x2_t acc_lo = vdupq_n_f64(0.0), acc_hi = vdupq_n_f64(0.0);
int i = 0;
#endif

for (; i <= n - 4; i += 4) {
uint16x4_t av16 = vld1_u16(a + i);
Expand All @@ -588,7 +680,11 @@ float float16_distance_dot_neon (const void *v1, const void *v2, int n) {
if (isnan(x) || isnan(y)) continue;
double p = (double)x * (double)y;
if (isinf(p)) return (p>0)? -INFINITY : INFINITY;
#ifdef _ARM32BIT_
dot += p;
#else
acc_lo = vsetq_lane_f64(vgetq_lane_f64(acc_lo,0)+p, acc_lo, 0); /* cheap add */
#endif
}
continue;
}
Expand All @@ -603,13 +699,26 @@ float float16_distance_dot_neon (const void *v1, const void *v2, int n) {
by = vbslq_f32(my, by, vdupq_n_f32(0.0f));

float32x4_t prod = vmulq_f32(ax, by);

#ifdef _ARM32BIT_
// 32-bit ARM: accumulate in scalar double
float prod_tmp[4];
vst1q_f32(prod_tmp, prod);
for (int j = 0; j < 4; j++) {
dot += (double)prod_tmp[j];
}
#else
// 64-bit ARM: use NEON f64 operations
float64x2_t lo = vcvt_f64_f32(vget_low_f32(prod));
float64x2_t hi = vcvt_f64_f32(vget_high_f32(prod));
acc_lo = vaddq_f64(acc_lo, lo);
acc_hi = vaddq_f64(acc_hi, hi);
#endif
}

#ifndef _ARM32BIT_
double dot = vaddvq_f64(vaddq_f64(acc_lo, acc_hi));
#endif

for (; i < n; ++i) {
float x = float16_to_float32(a[i]);
Expand All @@ -635,8 +744,15 @@ float float16_distance_l1_neon (const void *v1, const void *v2, int n) {
const uint16x4_t SIGN_MASK = vdup_n_u16(0x8000u);
const uint16x4_t ZERO16 = vdup_n_u16(0);

#ifdef _ARM32BIT_
// 32-bit ARM: use scalar double accumulation
double sum = 0.0;
int i = 0;
#else
// 64-bit ARM: use float64x2_t NEON intrinsics
float64x2_t acc = vdupq_n_f64(0.0);
int i = 0;
#endif

for (; i <= n - 4; i += 4) {
uint16x4_t av16 = vld1_u16(a + i);
Expand Down Expand Up @@ -665,13 +781,25 @@ float float16_distance_l1_neon (const void *v1, const void *v2, int n) {
uint32x4_t m = vceqq_f32(d, d); /* mask NaNs -> 0 */
d = vbslq_f32(m, d, vdupq_n_f32(0.0f));

#ifdef _ARM32BIT_
// 32-bit ARM: accumulate in scalar double
float tmp[4];
vst1q_f32(tmp, d);
for (int j = 0; j < 4; j++) {
sum += (double)tmp[j];
}
#else
// 64-bit ARM: use NEON f64 operations
float64x2_t lo = vcvt_f64_f32(vget_low_f32(d));
float64x2_t hi = vcvt_f64_f32(vget_high_f32(d));
acc = vaddq_f64(acc, lo);
acc = vaddq_f64(acc, hi);
#endif
}

#ifndef _ARM32BIT_
double sum = vaddvq_f64(acc);
#endif

for (; i < n; ++i) {
uint16_t ai=a[i], bi=b[i];
Expand Down