diff --git a/.github/actions/build-diskann-native/action.yml b/.github/actions/build-diskann-native/action.yml new file mode 100644 index 000000000000..347cfb82c769 --- /dev/null +++ b/.github/actions/build-diskann-native/action.yml @@ -0,0 +1,63 @@ +name: 'Build DiskANN Native Library' +description: 'Build DiskANN native library (Rust JNI) for specified platform' +inputs: + platform: + description: 'Target platform (linux-amd64, linux-aarch64, darwin-aarch64)' + required: true + rust-version: + description: 'Rust toolchain version' + required: false + default: 'stable' + +runs: + using: 'composite' + steps: + - name: Install Rust toolchain + shell: bash + run: | + curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y --default-toolchain ${{ inputs.rust-version }} + echo "$HOME/.cargo/bin" >> $GITHUB_PATH + source "$HOME/.cargo/env" + rustc --version + cargo --version + + - name: Install dependency bundling tools (Linux) + if: startsWith(inputs.platform, 'linux') + shell: bash + run: | + sudo apt-get update -qq && sudo apt-get install -y -qq patchelf + + - name: Build native library + shell: bash + run: | + ./paimon-diskann/paimon-diskann-jni/scripts/build-native.sh --clean + + - name: List built libraries (Linux AMD64) + if: inputs.platform == 'linux-amd64' + shell: bash + run: | + echo "=== Built libraries ===" + ls -la paimon-diskann/paimon-diskann-jni/src/main/resources/linux/amd64/ + echo "" + echo "=== Library dependencies ===" + ldd paimon-diskann/paimon-diskann-jni/src/main/resources/linux/amd64/libpaimon_diskann_jni.so || true + + - name: List built libraries (Linux AARCH64) + if: inputs.platform == 'linux-aarch64' + shell: bash + run: | + echo "=== Built libraries ===" + ls -la paimon-diskann/paimon-diskann-jni/src/main/resources/linux/aarch64/ + echo "" + echo "=== Library dependencies ===" + ldd paimon-diskann/paimon-diskann-jni/src/main/resources/linux/aarch64/libpaimon_diskann_jni.so || true + + - name: List built libraries (macOS) + if: inputs.platform == 'darwin-aarch64' + shell: bash + run: | + echo "=== Built libraries ===" + ls -la paimon-diskann/paimon-diskann-jni/src/main/resources/darwin/aarch64/ + echo "" + echo "=== Library dependencies ===" + otool -L paimon-diskann/paimon-diskann-jni/src/main/resources/darwin/aarch64/libpaimon_diskann_jni.dylib || true diff --git a/.github/workflows/build-diskann-native.yml b/.github/workflows/build-diskann-native.yml new file mode 100644 index 000000000000..f4e7ae00b540 --- /dev/null +++ b/.github/workflows/build-diskann-native.yml @@ -0,0 +1,68 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+################################################################################ + +name: Build DiskANN Native Library + +on: + workflow_call: + inputs: + platform: + description: 'Target platform (linux-amd64, linux-aarch64, darwin-aarch64)' + required: false + type: string + default: 'linux-amd64' + jdk-version: + description: 'JDK version to use' + required: false + type: string + default: '8' + artifact-name: + description: 'Name for the uploaded artifact' + required: false + type: string + default: 'diskann-native-linux-amd64' + retention-days: + description: 'Number of days to retain the artifact' + required: false + type: number + default: 1 + +jobs: + build_native: + runs-on: ubuntu-latest + steps: + - name: Checkout code + uses: actions/checkout@v6 + + - name: Set up JDK ${{ inputs.jdk-version }} + uses: actions/setup-java@v5 + with: + java-version: ${{ inputs.jdk-version }} + distribution: 'temurin' + + - name: Build DiskANN native library + uses: ./.github/actions/build-diskann-native + with: + platform: ${{ inputs.platform }} + + - name: Upload native library artifact + uses: actions/upload-artifact@v6 + with: + name: ${{ inputs.artifact-name }} + path: paimon-diskann/paimon-diskann-jni/src/main/resources/ + retention-days: ${{ inputs.retention-days }} diff --git a/.github/workflows/publish-diskann_snapshot.yml b/.github/workflows/publish-diskann_snapshot.yml new file mode 100644 index 000000000000..93384ea44eed --- /dev/null +++ b/.github/workflows/publish-diskann_snapshot.yml @@ -0,0 +1,184 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+################################################################################ + +name: Publish DiskANN Snapshot + +on: + schedule: + # At the end of every day + - cron: '0 1 * * *' + workflow_dispatch: + push: + paths: + - 'paimon-diskann/**' + branches: + - master + +env: + JDK_VERSION: 8 + MAVEN_OPTS: -Dmaven.wagon.httpconnectionManager.ttlSeconds=30 -Dmaven.wagon.http.retryHandler.requestSentEnabled=true + +concurrency: + group: ${{ github.workflow }}-${{ github.event_name }}-${{ github.event.number || github.run_id }} + cancel-in-progress: true + +jobs: + # Build native library for Linux AMD64 + build-native-linux-amd64: + runs-on: ubuntu-latest + steps: + - name: Checkout code + uses: actions/checkout@v6 + + - name: Set up JDK ${{ env.JDK_VERSION }} + uses: actions/setup-java@v5 + with: + java-version: ${{ env.JDK_VERSION }} + distribution: 'temurin' + + - name: Build DiskANN native library + uses: ./.github/actions/build-diskann-native + with: + platform: linux-amd64 + + - name: Upload native library + uses: actions/upload-artifact@v6 + with: + name: native-linux-amd64 + path: paimon-diskann/paimon-diskann-jni/src/main/resources/linux/amd64/ + retention-days: 1 + + # Build native library for Linux AARCH64 + build-native-linux-aarch64: + runs-on: ubuntu-24.04-arm + steps: + - name: Checkout code + uses: actions/checkout@v6 + + - name: Set up JDK ${{ env.JDK_VERSION }} + uses: actions/setup-java@v5 + with: + java-version: ${{ env.JDK_VERSION }} + distribution: 'temurin' + + - name: Build DiskANN native library + uses: ./.github/actions/build-diskann-native + with: + platform: linux-aarch64 + + - name: Upload native library + uses: actions/upload-artifact@v6 + with: + name: native-linux-aarch64 + path: paimon-diskann/paimon-diskann-jni/src/main/resources/linux/aarch64/ + retention-days: 1 + + # Build native library for macOS ARM (Apple Silicon) + build-native-macos-arm: + runs-on: macos-14 + steps: + - name: Checkout code + uses: actions/checkout@v6 + + - name: Set up JDK ${{ env.JDK_VERSION }} + uses: actions/setup-java@v5 + with: + java-version: ${{ env.JDK_VERSION }} + distribution: 'zulu' + + - name: Build DiskANN native library + uses: ./.github/actions/build-diskann-native + with: + platform: darwin-aarch64 + + - name: Upload native library + uses: actions/upload-artifact@v6 + with: + name: native-darwin-aarch64 + path: paimon-diskann/paimon-diskann-jni/src/main/resources/darwin/aarch64/ + retention-days: 1 + + # Package and publish + package-and-publish: + if: github.repository == 'apache/paimon' + needs: [build-native-linux-amd64, build-native-linux-aarch64, build-native-macos-arm] + runs-on: ubuntu-latest + steps: + - name: Checkout code + uses: actions/checkout@v6 + + - name: Set up JDK ${{ env.JDK_VERSION }} + uses: actions/setup-java@v5 + with: + java-version: ${{ env.JDK_VERSION }} + distribution: 'zulu' + + - name: Download Linux AMD64 native library + uses: actions/download-artifact@v7 + with: + name: native-linux-amd64 + path: paimon-diskann/paimon-diskann-jni/src/main/resources/linux/amd64/ + + - name: Download Linux AARCH64 native library + uses: actions/download-artifact@v7 + with: + name: native-linux-aarch64 + path: paimon-diskann/paimon-diskann-jni/src/main/resources/linux/aarch64/ + + - name: Download macOS ARM native library + uses: actions/download-artifact@v7 + with: + name: native-darwin-aarch64 + path: paimon-diskann/paimon-diskann-jni/src/main/resources/darwin/aarch64/ + + - name: List all native libraries + run: | + echo "=== All native libraries ===" + 
find paimon-diskann/paimon-diskann-jni/src/main/resources -type f \( -name "*.so" -o -name "*.so.*" -o -name "*.dylib" \) -exec ls -la {} \; + + - name: Cache local Maven repository + uses: actions/cache@v5 + with: + path: ~/.m2/repository + key: diskann-snapshot-maven-${{ hashFiles('**/pom.xml') }} + restore-keys: | + diskann-snapshot-maven- + + - name: Build and package paimon-diskann-jni + run: | + mvn -B -ntp clean install -pl paimon-diskann/paimon-diskann-jni -am -DskipTests -Ppaimon-diskann -Drat.skip + + - name: Build and package paimon-diskann-index + run: | + mvn -B -ntp clean install -pl paimon-diskann/paimon-diskann-index -am -DskipTests -Ppaimon-diskann -Drat.skip + + - name: Publish snapshot + env: + ASF_USERNAME: ${{ secrets.NEXUS_USER }} + ASF_PASSWORD: ${{ secrets.NEXUS_PW }} + MAVEN_OPTS: -Xmx4096m + run: | + tmp_settings="tmp-settings.xml" + echo "" > $tmp_settings + echo "apache.snapshots.https$ASF_USERNAME" >> $tmp_settings + echo "$ASF_PASSWORD" >> $tmp_settings + echo "" >> $tmp_settings + + mvn --settings $tmp_settings -ntp deploy -pl paimon-diskann/paimon-diskann-jni,paimon-diskann/paimon-diskann-index -Dgpg.skip -Drat.skip -DskipTests -Ppaimon-diskann + + rm $tmp_settings diff --git a/.github/workflows/utitcase.yml b/.github/workflows/utitcase.yml index 915ec0385ac9..b64c169cf7b6 100644 --- a/.github/workflows/utitcase.yml +++ b/.github/workflows/utitcase.yml @@ -29,6 +29,8 @@ on: - 'paimon-lucene/**' - 'paimon-faiss/**' - '.github/workflows/faiss-vector-index-tests.yml' + - 'paimon-diskann/**' + - '.github/workflows/publish-diskann_snapshot.yml' - 'paimon-core/src/test/java/org/apache/paimon/JavaPyE2ETest.java' env: @@ -40,15 +42,21 @@ concurrency: cancel-in-progress: true jobs: - build_native: + build_faiss_native: uses: ./.github/workflows/build-faiss-native.yml with: platform: linux-amd64 jdk-version: '8' + build_diskann_native: + uses: ./.github/workflows/build-diskann-native.yml + with: + platform: linux-amd64 + jdk-version: '8' + build_test: runs-on: ubuntu-latest - needs: build_native + needs: [build_faiss_native, build_diskann_native] steps: - name: Checkout code @@ -60,21 +68,30 @@ jobs: java-version: ${{ env.JDK_VERSION }} distribution: 'temurin' - - name: Download native library artifact + - name: Download FAISS native library artifact uses: actions/download-artifact@v7 with: name: faiss-native-linux-amd64 path: paimon-faiss/paimon-faiss-jni/src/main/resources/linux/amd64/ - - name: List downloaded native library + - name: Download DiskANN native library artifact + uses: actions/download-artifact@v7 + with: + name: diskann-native-linux-amd64 + path: paimon-diskann/paimon-diskann-jni/src/main/resources/ + + - name: List downloaded native libraries run: | - echo "=== Downloaded native libraries ===" + echo "=== FAISS native libraries ===" ls -la paimon-faiss/paimon-faiss-jni/src/main/resources/linux/amd64/ + echo "" + echo "=== DiskANN native libraries ===" + find paimon-diskann/paimon-diskann-jni/src/main/resources -type f -exec ls -la {} \; - name: Build Others run: | echo "Start compiling modules" - mvn -T 2C -B -ntp clean install -DskipTests -Pflink1,spark3,paimon-faiss + mvn -T 2C -B -ntp clean install -DskipTests -Pflink1,spark3,paimon-faiss,paimon-diskann - name: Test Others timeout-minutes: 60 diff --git a/.gitignore b/.gitignore index 3f42fdc44a97..400cb9104d7e 100644 --- a/.gitignore +++ b/.gitignore @@ -37,3 +37,5 @@ paimon-faiss/paimon-faiss-jni/build/ paimon-faiss/paimon-faiss-jni/src/main/resources/darwin* 
paimon-faiss/paimon-faiss-jni/src/main/resources/linux* paimon-faiss/paimon-faiss-jni/src/main/native/cmake-build-debug/ +paimon-diskann/paimon-diskann-jni/src/main/resources/darwin* +paimon-diskann/paimon-diskann-jni/src/main/resources/linux* diff --git a/paimon-diskann/PARAMETER_TUNING.md b/paimon-diskann/PARAMETER_TUNING.md new file mode 100644 index 000000000000..056f1058ec99 --- /dev/null +++ b/paimon-diskann/PARAMETER_TUNING.md @@ -0,0 +1,172 @@ + + +# DiskANN Parameter Tuning Guide + +This document provides guidance on tuning DiskANN vector index parameters for optimal performance in Apache Paimon. + +## Overview + +DiskANN is a graph-based approximate nearest neighbor (ANN) search algorithm designed for efficient billion-point vector search. The implementation in Paimon provides several parameters to control the trade-offs between accuracy, speed, and resource usage. + +## Key Parameters + +### 1. Graph Construction Parameters + +#### `vector.diskann.max-degree` (R) +- **Default**: 64 +- **Range**: 32-128 +- **Description**: Maximum degree (number of connections) for each node in the graph +- **Impact**: + - Higher values → Better recall, higher memory usage, longer build time + - Lower values → Faster build, lower memory, potentially lower recall +- **Recommendations**: + - **32**: For memory-constrained environments or when build time is critical + - **64**: Balanced default (Microsoft recommended) + - **128**: For maximum recall when resources permit + +#### `vector.diskann.build-list-size` (L) +- **Default**: 100 +- **Range**: 50-200 +- **Description**: Size of the candidate list during graph construction +- **Impact**: + - Higher values → Better graph quality, longer build time + - Lower values → Faster build, potentially lower recall +- **Recommendations**: + - Use default 100 for most cases + - Increase to 150-200 for very high-dimensional data (>512 dimensions) + +### 2. Search Parameters + +#### `vector.diskann.search-list-size` (L) +- **Default**: 100 +- **Range**: 16-500 +- **Description**: Size of the candidate list during search +- **Impact**: + - Higher values → Better recall, higher latency + - Lower values → Lower latency, potentially lower recall +- **Dynamic Behavior**: The implementation automatically adjusts this to be at least equal to the requested `k` (number of results) +- **Recommendations**: + - **16-32**: For latency-critical applications (QPS > 5000) + - **100**: Balanced default + - **200-500**: For maximum recall (recall > 95%) + +#### `vector.search-factor` +- **Default**: 10 +- **Range**: 5-20 +- **Description**: Multiplier for search limit when row filtering is applied +- **Impact**: When filtering by row IDs, fetches `limit * search-factor` results to ensure sufficient matches after filtering +- **Recommendations**: + - **5**: When filtering is selective (<10% of data) + - **10**: Default for typical filtering scenarios + - **20**: When filtering is very broad (>50% of data) + +### 3. Data Configuration + +#### `vector.dim` +- **Default**: 128 +- **Description**: Dimension of the vectors +- **Recommendations**: + - Must match your embedding model + - Common values: 128, 256, 384, 512, 768, 1024 + +#### `vector.metric` +- **Default**: L2 +- **Options**: L2, INNER_PRODUCT, COSINE +- **Description**: Distance metric for similarity computation +- **Recommendations**: + - **L2**: For Euclidean distance (most common) + - **INNER_PRODUCT**: For dot product similarity + - **COSINE**: For cosine similarity + +### 4. 
Index Organization + +#### `vector.size-per-index` +- **Default**: 2,000,000 +- **Description**: Number of vectors per index file +- **Impact**: + - Larger values → Fewer files, higher memory per index, better search efficiency + - Smaller values → More files, lower memory per index, more overhead +- **Recommendations**: + - **500,000**: For small datasets or memory-constrained environments + - **2,000,000**: Default for balanced performance + - **5,000,000+**: For large-scale production systems with ample resources + +## Performance Tuning Guide + +### High Recall (>95%) +```properties +vector.diskann.max-degree = 128 +vector.diskann.build-list-size = 150 +vector.diskann.search-list-size = 200 +``` + +### Balanced (90-95% recall) +```properties +vector.diskann.max-degree = 64 +vector.diskann.build-list-size = 100 +vector.diskann.search-list-size = 100 +``` + +### High QPS (Low Latency) +```properties +vector.diskann.max-degree = 32 +vector.diskann.build-list-size = 75 +vector.diskann.search-list-size = 32 +``` + +## Best Practices + +1. **Start with defaults**: The default parameters are tuned for balanced performance +2. **Measure first**: Profile your workload before tuning +3. **Tune incrementally**: Change one parameter at a time and measure impact +4. **Consider trade-offs**: Higher recall typically means higher latency and resource usage +5. **Test with production data**: Parameter effectiveness depends on data characteristics + +## Advanced Parameters (Future Enhancement) + +The following parameters are documented in the official Microsoft DiskANN implementation but are not yet exposed in the current Rust-based native library: + +- **alpha** (default: 1.2): Controls the graph construction pruning strategy +- **saturate_graph** (default: true): Whether to saturate the graph during construction + +These parameters may be added in future versions when the underlying Rust DiskANN crate exposes them through its configuration API. + +## Performance Metrics + +When tuning parameters, monitor these metrics: +- **Recall**: Percentage of true nearest neighbors found +- **QPS (Queries Per Second)**: Throughput of search operations +- **Latency**: Time to complete a single query (p50, p95, p99) +- **Memory Usage**: RAM consumed by indices +- **Build Time**: Time to construct the index + +## Recent Improvements + +### Dynamic Search List Sizing (v1.0+) +The search list size is now automatically adjusted to be at least equal to the requested `k`. This follows Milvus best practices and ensures optimal recall without manual tuning. + +### Memory-Efficient Loading (v1.0+) +Indices are now loaded through temporary files, allowing the OS to manage memory more efficiently for large indices. This is a step toward full mmap support. 
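+
+## End-to-End Example
+
+The following sketch shows how the options above can be supplied when building and querying a DiskANN index, using the `sys.create_global_index` procedure and the `vector_search` table function exercised by this module's end-to-end tests. The table and column names are placeholders, and passing several tuning keys in one comma-separated `options` string is an assumption to verify against your Paimon version; only `vector.dim` appears in the tests.
+
+```sql
+-- Build a DiskANN vector index on column `v` of table `T` (3-dimensional vectors).
+CALL sys.create_global_index(
+  table => 'db.T',
+  index_column => 'v',
+  index_type => 'diskann-vector-ann',
+  options => 'vector.dim=3,vector.metric=L2,vector.diskann.max-degree=64,vector.diskann.search-list-size=100');
+
+-- Retrieve the 5 approximate nearest neighbors of the query vector.
+SELECT * FROM vector_search('T', 'v', array(50.0f, 51.0f, 52.0f), 5);
+```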
+ +## References + +- [Microsoft DiskANN Paper](https://proceedings.neurips.cc/paper/2019/file/09853c7fb1d3f8ee67a61b6bf4a7f8e6-Paper.pdf) +- [Microsoft DiskANN Library](https://github.com/microsoft/DiskANN) +- [Milvus DiskANN Documentation](https://milvus.io/docs/diskann.md) diff --git a/paimon-diskann/paimon-diskann-e2e-test/pom.xml b/paimon-diskann/paimon-diskann-e2e-test/pom.xml new file mode 100644 index 000000000000..d74a68a7de39 --- /dev/null +++ b/paimon-diskann/paimon-diskann-e2e-test/pom.xml @@ -0,0 +1,293 @@ + + + + 4.0.0 + + + org.apache.paimon + paimon-diskann + 1.4-SNAPSHOT + + + paimon-diskann-e2e-test + Paimon : DiskANN End to End Tests + + + java8 + 3.5 + 3.5.8 + + + + + + + + com.fasterxml.jackson.core + jackson-core + 2.15.2 + test + + + + com.fasterxml.jackson.core + jackson-databind + 2.15.2 + test + + + + com.fasterxml.jackson.core + jackson-annotations + 2.15.2 + test + + + + org.apache.paimon + paimon-format + ${project.version} + test + + + + org.apache.paimon + paimon-spark-${test.spark.main.version}_${scala.binary.version} + ${project.version} + test + + + + org.apache.paimon + paimon-diskann-index + ${project.version} + test + + + + org.apache.paimon + paimon-test-utils + ${project.version} + test + + + + org.apache.paimon + paimon-spark-ut_${scala.binary.version} + ${project.version} + test + test-jar + + + + + org.apache.spark + spark-sql_${scala.binary.version} + ${test.spark.version} + test + + + org.apache.logging.log4j + log4j-slf4j2-impl + + + org.slf4j + jul-to-slf4j + + + org.slf4j + jcl-over-slf4j + + + + + + org.apache.spark + spark-sql_${scala.binary.version} + ${test.spark.version} + tests + test + + + org.apache.logging.log4j + log4j-slf4j2-impl + + + org.slf4j + jul-to-slf4j + + + org.slf4j + jcl-over-slf4j + + + + + + org.apache.spark + spark-catalyst_${scala.binary.version} + ${test.spark.version} + tests + test + + + + org.apache.spark + spark-core_${scala.binary.version} + ${test.spark.version} + test + + + org.apache.logging.log4j + log4j-slf4j2-impl + + + org.slf4j + jul-to-slf4j + + + org.slf4j + jcl-over-slf4j + + + + + + org.apache.spark + spark-core_${scala.binary.version} + ${test.spark.version} + tests + test + + + org.apache.logging.log4j + log4j-slf4j2-impl + + + org.slf4j + jul-to-slf4j + + + org.slf4j + jcl-over-slf4j + + + + + + org.apache.spark + spark-hive_${scala.binary.version} + ${test.spark.version} + test + + + + + org.scala-lang + scala-library + ${scala.version} + test + + + + + org.scalatest + scalatest_${scala.binary.version} + 3.2.14 + test + + + + + + + org.apache.maven.plugins + maven-deploy-plugin + + true + + + + + net.alchim31.maven + scala-maven-plugin + ${scala-maven-plugin.version} + + + scala-compile-first + process-resources + + add-source + compile + + + + scala-test-compile + process-test-resources + + testCompile + + + + + ${scala.version} + + -nobootcp + + + + + + org.scalatest + scalatest-maven-plugin + ${scalatest-maven-plugin.version} + + ${project.build.directory}/surefire-reports + . 
+ TestSuiteReport.txt + -ea -Xmx4g -Xss4m -XX:MaxMetaspaceSize=2g -XX:ReservedCodeCacheSize=${CodeCacheSize} ${extraJavaTestArgs} -Dio.netty.tryReflectionSetAccessible=true + org.apache.paimon.spark.sql.DiskAnnVectorIndexE2ETest + + + + test + + test + + + + + + + + + + java11 + + [11,) + + + java11 + + + + diff --git a/paimon-diskann/paimon-diskann-e2e-test/src/test/scala/org/apache/paimon/spark/sql/DiskAnnVectorIndexE2ETest.scala b/paimon-diskann/paimon-diskann-e2e-test/src/test/scala/org/apache/paimon/spark/sql/DiskAnnVectorIndexE2ETest.scala new file mode 100644 index 000000000000..e75ed462323e --- /dev/null +++ b/paimon-diskann/paimon-diskann-e2e-test/src/test/scala/org/apache/paimon/spark/sql/DiskAnnVectorIndexE2ETest.scala @@ -0,0 +1,389 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.paimon.spark.sql + +import org.apache.paimon.spark.PaimonSparkTestBase + +import scala.collection.JavaConverters._ + +/** End-to-end tests for DiskANN vector index read/write operations on Spark 3.5. 
*/ +class DiskAnnVectorIndexE2ETest extends PaimonSparkTestBase { + + // ========== Index Creation Tests ========== + + test("create diskann vector index - basic") { + withTable("T") { + spark.sql(""" + |CREATE TABLE T (id INT, v ARRAY) + |TBLPROPERTIES ( + | 'bucket' = '-1', + | 'global-index.row-count-per-shard' = '10000', + | 'row-tracking.enabled' = 'true', + | 'data-evolution.enabled' = 'true') + |""".stripMargin) + + val values = (0 until 100) + .map( + i => s"($i, array(cast($i as float), cast(${i + 1} as float), cast(${i + 2} as float)))") + .mkString(",") + spark.sql(s"INSERT INTO T VALUES $values") + + val output = spark + .sql("CALL sys.create_global_index(table => 'test.T', index_column => 'v', index_type => 'diskann-vector-ann', options => 'vector.dim=3')") + .collect() + .head + assert(output.getBoolean(0)) + + val table = loadTable("T") + val indexEntries = table + .store() + .newIndexFileHandler() + .scanEntries() + .asScala + .filter(_.indexFile().indexType() == "diskann-vector-ann") + + assert(indexEntries.nonEmpty) + val totalRowCount = indexEntries.map(_.indexFile().rowCount()).sum + assert(totalRowCount == 100L) + } + } + + test("create diskann vector index - with different index types") { + withTable("T") { + spark.sql(""" + |CREATE TABLE T (id INT, v ARRAY) + |TBLPROPERTIES ( + | 'bucket' = '-1', + | 'global-index.row-count-per-shard' = '10000', + | 'row-tracking.enabled' = 'true', + | 'data-evolution.enabled' = 'true') + |""".stripMargin) + + val values = (0 until 50) + .map( + i => s"($i, array(cast($i as float), cast(${i + 1} as float), cast(${i + 2} as float)))") + .mkString(",") + spark.sql(s"INSERT INTO T VALUES $values") + + val output = spark + .sql("CALL sys.create_global_index(table => 'test.T', index_column => 'v', index_type => 'diskann-vector-ann', options => 'vector.dim=3')") + .collect() + .head + assert(output.getBoolean(0)) + + val table = loadTable("T") + val indexEntries = table + .store() + .newIndexFileHandler() + .scanEntries() + .asScala + .filter(_.indexFile().indexType() == "diskann-vector-ann") + + assert(indexEntries.nonEmpty) + } + } + + test("create diskann vector index - with partitioned table") { + withTable("T") { + spark.sql(""" + |CREATE TABLE T (id INT, v ARRAY, pt STRING) + |TBLPROPERTIES ( + | 'bucket' = '-1', + | 'global-index.row-count-per-shard' = '10000', + | 'row-tracking.enabled' = 'true', + | 'data-evolution.enabled' = 'true') + | PARTITIONED BY (pt) + |""".stripMargin) + + var values = (0 until 500) + .map( + i => + s"($i, array(cast($i as float), cast(${i + 1} as float), cast(${i + 2} as float)), 'p0')") + .mkString(",") + spark.sql(s"INSERT INTO T VALUES $values") + + values = (0 until 300) + .map( + i => + s"($i, array(cast($i as float), cast(${i + 1} as float), cast(${i + 2} as float)), 'p1')") + .mkString(",") + spark.sql(s"INSERT INTO T VALUES $values") + + val output = spark + .sql("CALL sys.create_global_index(table => 'test.T', index_column => 'v', index_type => 'diskann-vector-ann', options => 'vector.dim=3')") + .collect() + .head + assert(output.getBoolean(0)) + + val table = loadTable("T") + val indexEntries = table + .store() + .newIndexFileHandler() + .scanEntries() + .asScala + .filter(_.indexFile().indexType() == "diskann-vector-ann") + + assert(indexEntries.nonEmpty) + val totalRowCount = indexEntries.map(_.indexFile().rowCount()).sum + assert(totalRowCount == 800L) + } + } + + // ========== Index Write Tests ========== + + test("write vectors - large dataset") { + withTable("T") { + spark.sql(""" + 
|CREATE TABLE T (id INT, v ARRAY) + |TBLPROPERTIES ( + | 'bucket' = '-1', + | 'global-index.row-count-per-shard' = '10000', + | 'row-tracking.enabled' = 'true', + | 'data-evolution.enabled' = 'true') + |""".stripMargin) + + val values = (0 until 10000) + .map( + i => s"($i, array(cast($i as float), cast(${i + 1} as float), cast(${i + 2} as float)))") + .mkString(",") + spark.sql(s"INSERT INTO T VALUES $values") + + val output = spark + .sql("CALL sys.create_global_index(table => 'test.T', index_column => 'v', index_type => 'diskann-vector-ann', options => 'vector.dim=3')") + .collect() + .head + assert(output.getBoolean(0)) + + val table = loadTable("T") + val indexEntries = table + .store() + .newIndexFileHandler() + .scanEntries() + .asScala + .filter(_.indexFile().indexType() == "diskann-vector-ann") + + val totalRowCount = indexEntries.map(_.indexFile().rowCount()).sum + assert(totalRowCount == 10000L) + } + } + + // ========== Index Read/Search Tests ========== + + test("read vectors - basic search") { + withTable("T") { + spark.sql(""" + |CREATE TABLE T (id INT, v ARRAY) + |TBLPROPERTIES ( + | 'bucket' = '-1', + | 'global-index.row-count-per-shard' = '10000', + | 'row-tracking.enabled' = 'true', + | 'data-evolution.enabled' = 'true') + |""".stripMargin) + + val values = (0 until 100) + .map( + i => s"($i, array(cast($i as float), cast(${i + 1} as float), cast(${i + 2} as float)))") + .mkString(",") + spark.sql(s"INSERT INTO T VALUES $values") + + spark + .sql("CALL sys.create_global_index(table => 'test.T', index_column => 'v', index_type => 'diskann-vector-ann', options => 'vector.dim=3')") + .collect() + + val result = spark + .sql(""" + |SELECT * FROM vector_search('T', 'v', array(50.0f, 51.0f, 52.0f), 5) + |""".stripMargin) + .collect() + assert(result.map(_.getInt(0)).contains(50)) + } + } + + test("read vectors - top-k search with different k values") { + withTable("T") { + spark.sql(""" + |CREATE TABLE T (id INT, v ARRAY) + |TBLPROPERTIES ( + | 'bucket' = '-1', + | 'global-index.row-count-per-shard' = '10000', + | 'row-tracking.enabled' = 'true', + | 'data-evolution.enabled' = 'true') + |""".stripMargin) + + val values = (0 until 200) + .map( + i => s"($i, array(cast($i as float), cast(${i + 1} as float), cast(${i + 2} as float)))") + .mkString(",") + spark.sql(s"INSERT INTO T VALUES $values") + + spark + .sql("CALL sys.create_global_index(table => 'test.T', index_column => 'v', index_type => 'diskann-vector-ann', options => 'vector.dim=3')") + .collect() + + var result = spark + .sql(""" + |SELECT * FROM vector_search('T', 'v', array(100.0f, 101.0f, 102.0f), 1) + |""".stripMargin) + .collect() + assert(result.length == 1) + assert(result.head.getInt(0) == 100) + + result = spark + .sql(""" + |SELECT * FROM vector_search('T', 'v', array(100.0f, 101.0f, 102.0f), 10) + |""".stripMargin) + .collect() + assert(result.map(_.getInt(0)).contains(100)) + } + } + + test("read vectors - multiple concurrent searches") { + withTable("T") { + spark.sql(""" + |CREATE TABLE T (id INT, v ARRAY) + |TBLPROPERTIES ( + | 'bucket' = '-1', + | 'global-index.row-count-per-shard' = '10000', + | 'row-tracking.enabled' = 'true', + | 'data-evolution.enabled' = 'true') + |""".stripMargin) + + val values = (0 until 500) + .map( + i => s"($i, array(cast($i as float), cast(${i + 1} as float), cast(${i + 2} as float)))") + .mkString(",") + spark.sql(s"INSERT INTO T VALUES $values") + + spark + .sql("CALL sys.create_global_index(table => 'test.T', index_column => 'v', index_type => 'diskann-vector-ann', 
options => 'vector.dim=3')") + .collect() + + val result1 = spark + .sql(""" + |SELECT * FROM vector_search('T', 'v', array(10.0f, 11.0f, 12.0f), 3) + |""".stripMargin) + .collect() + assert(result1.length == 3) + assert(result1.map(_.getInt(0)).contains(10)) + + val result2 = spark + .sql(""" + |SELECT * FROM vector_search('T', 'v', array(250.0f, 251.0f, 252.0f), 5) + |""".stripMargin) + .collect() + assert(result2.map(_.getInt(0)).contains(250)) + + val result3 = spark + .sql(""" + |SELECT * FROM vector_search('T', 'v', array(450.0f, 451.0f, 452.0f), 7) + |""".stripMargin) + .collect() + assert(result3.map(_.getInt(0)).contains(450)) + } + } + + test("read vectors - normalized vectors search") { + withTable("T") { + spark.sql(""" + |CREATE TABLE T (id INT, v ARRAY) + |TBLPROPERTIES ( + | 'bucket' = '-1', + | 'global-index.row-count-per-shard' = '10000', + | 'row-tracking.enabled' = 'true', + | 'data-evolution.enabled' = 'true') + |""".stripMargin) + + val values = (1 to 100) + .map { + i => + val v = math.sqrt(3.0 * i * i) + val normalized = i.toFloat / v.toFloat + s"($i, array($normalized, $normalized, $normalized))" + } + .mkString(",") + spark.sql(s"INSERT INTO T VALUES $values") + + spark.sql( + "CALL sys.create_global_index(table => 'test.T', index_column => 'v', index_type => 'diskann-vector-ann', options => 'vector.dim=3')") + + val result = spark + .sql(""" + |SELECT * FROM vector_search('T', 'v', array(0.577f, 0.577f, 0.577f), 10) + |""".stripMargin) + .collect() + + assert(result.length == 10) + } + } + + // ========== Integration Tests ========== + + test("end-to-end: write, index, read cycle") { + withTable("T") { + spark.sql(""" + |CREATE TABLE T (id INT, name STRING, embedding ARRAY) + |TBLPROPERTIES ( + | 'bucket' = '-1', + | 'global-index.row-count-per-shard' = '10000', + | 'row-tracking.enabled' = 'true', + | 'data-evolution.enabled' = 'true') + |""".stripMargin) + + val values = (0 until 1000) + .map( + i => + s"($i, 'item_$i', array(cast($i as float), cast(${i + 1} as float), cast(${i + 2} as float)))") + .mkString(",") + spark.sql(s"INSERT INTO T VALUES $values") + + val indexResult = spark + .sql("CALL sys.create_global_index(table => 'test.T', index_column => 'embedding', index_type => 'diskann-vector-ann', options => 'vector.dim=3')") + .collect() + .head + assert(indexResult.getBoolean(0)) + + val table = loadTable("T") + val indexEntries = table + .store() + .newIndexFileHandler() + .scanEntries() + .asScala + .filter(_.indexFile().indexType() == "diskann-vector-ann") + assert(indexEntries.nonEmpty) + + var searchResult = spark + .sql( + """ + |SELECT id, name FROM vector_search('T', 'embedding', array(500.0f, 501.0f, 502.0f), 10) + |""".stripMargin) + .collect() + + assert(searchResult.exists(row => row.getInt(0) == 500 && row.getString(1) == "item_500")) + + searchResult = spark + .sql( + """ + |SELECT id, name FROM vector_search('T', 'embedding', array(501.0f, 502.0f, 503.0f), 10) + |""".stripMargin) + .collect() + + assert(searchResult.exists(row => row.getInt(0) == 501 && row.getString(1) == "item_501")) + } + } +} diff --git a/paimon-diskann/paimon-diskann-index/pom.xml b/paimon-diskann/paimon-diskann-index/pom.xml new file mode 100644 index 000000000000..c71637adcadb --- /dev/null +++ b/paimon-diskann/paimon-diskann-index/pom.xml @@ -0,0 +1,108 @@ + + + + 4.0.0 + + + org.apache.paimon + paimon-diskann + 1.4-SNAPSHOT + + + paimon-diskann-index + Paimon : DiskANN Index + + + + org.apache.paimon + paimon-common + ${project.version} + + + + org.apache.paimon + 
paimon-diskann-jni + ${project.version} + + + + + org.junit.jupiter + junit-jupiter + ${junit5.version} + test + + + + org.apache.paimon + paimon-core + ${project.version} + test + + + + org.apache.paimon + paimon-format + ${project.version} + test + + + + org.apache.paimon + paimon-test-utils + ${project.version} + test + + + + org.apache.hadoop + hadoop-client + ${hadoop.version} + test + + + log4j + log4j + + + org.slf4j + slf4j-log4j12 + + + + + + + + + maven-jar-plugin + + + + test-jar + + + + + + + diff --git a/paimon-diskann/paimon-diskann-index/src/main/java/org/apache/paimon/diskann/index/DiskAnnIndex.java b/paimon-diskann/paimon-diskann-index/src/main/java/org/apache/paimon/diskann/index/DiskAnnIndex.java new file mode 100644 index 000000000000..eae4aa0bc99c --- /dev/null +++ b/paimon-diskann/paimon-diskann-index/src/main/java/org/apache/paimon/diskann/index/DiskAnnIndex.java @@ -0,0 +1,143 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.paimon.diskann.index; + +import org.apache.paimon.diskann.Index; +import org.apache.paimon.diskann.MetricType; + +import java.io.Closeable; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; + +/** + * A wrapper class for DiskANN index with zero-copy support. + * + *

This class provides a safe Java API for interacting with native DiskANN indices using direct + * ByteBuffers for zero-copy data transfer. + */ +public class DiskAnnIndex implements Closeable { + + private final Index index; + private final int dimension; + private final int buildListSize; + private volatile boolean closed = false; + + private DiskAnnIndex(Index index, int dimension, int buildListSize) { + this.index = index; + this.dimension = dimension; + this.buildListSize = buildListSize; + } + + public static DiskAnnIndex create( + int dimension, DiskAnnVectorMetric metric, int maxDegree, int buildListSize) { + MetricType metricType = metric.toMetricType(); + Index index = Index.create(dimension, metricType, 0, maxDegree, buildListSize); + return new DiskAnnIndex(index, dimension, buildListSize); + } + + public void add(ByteBuffer vectorBuffer, int n) { + ensureOpen(); + validateVectorBuffer(vectorBuffer, n); + index.add(n, vectorBuffer); + } + + /** + * Build the index graph after adding vectors. + * + *

Uses the buildListSize parameter that was specified during index creation. + */ + public void build() { + ensureOpen(); + index.build(buildListSize); + } + + /** Return the number of bytes needed for serialization. */ + public long serializeSize() { + ensureOpen(); + return index.serializeSize(); + } + + /** + * Serialize this index with its Vamana graph adjacency lists into the given direct ByteBuffer. + * + *

The serialized data (graph + vectors, no header) is later split into an index file (graph + * only) and a data file (raw vectors) by the writer, then loaded by {@link + * DiskAnnVectorGlobalIndexReader} for search. Metadata is stored in {@link DiskAnnIndexMeta}. + * + * @param buffer a direct ByteBuffer of at least {@link #serializeSize()} bytes + * @return the number of bytes written + */ + public long serialize(ByteBuffer buffer) { + ensureOpen(); + if (!buffer.isDirect()) { + throw new IllegalArgumentException("Buffer must be a direct buffer"); + } + return index.serialize(buffer); + } + + /** + * Train a PQ codebook on the vectors in this index and encode all vectors. + * + * @param numSubspaces number of PQ subspaces (M). + * @param maxSamples maximum training samples for K-Means. + * @param kmeansIters number of K-Means iterations. + * @return {@code byte[2]}: [0] = serialized pivots, [1] = serialized compressed codes. + */ + public byte[][] pqTrainAndEncode(int numSubspaces, int maxSamples, int kmeansIters) { + ensureOpen(); + return index.pqTrainAndEncode(numSubspaces, maxSamples, kmeansIters); + } + + public static ByteBuffer allocateVectorBuffer(int numVectors, int dimension) { + return ByteBuffer.allocateDirect(numVectors * dimension * Float.BYTES) + .order(ByteOrder.nativeOrder()); + } + + private void validateVectorBuffer(ByteBuffer buffer, int numVectors) { + if (!buffer.isDirect()) { + throw new IllegalArgumentException("Vector buffer must be a direct buffer"); + } + int requiredBytes = numVectors * dimension * Float.BYTES; + if (buffer.capacity() < requiredBytes) { + throw new IllegalArgumentException( + "Vector buffer too small: required " + + requiredBytes + + " bytes, got " + + buffer.capacity()); + } + } + + private void ensureOpen() { + if (closed) { + throw new IllegalStateException("Index has been closed"); + } + } + + @Override + public void close() { + if (!closed) { + synchronized (this) { + if (!closed) { + index.close(); + closed = true; + } + } + } + } +} diff --git a/paimon-diskann/paimon-diskann-index/src/main/java/org/apache/paimon/diskann/index/DiskAnnIndexMeta.java b/paimon-diskann/paimon-diskann-index/src/main/java/org/apache/paimon/diskann/index/DiskAnnIndexMeta.java new file mode 100644 index 000000000000..45cf9f86dfa6 --- /dev/null +++ b/paimon-diskann/paimon-diskann-index/src/main/java/org/apache/paimon/diskann/index/DiskAnnIndexMeta.java @@ -0,0 +1,186 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.paimon.diskann.index; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.DataInputStream; +import java.io.DataOutputStream; +import java.io.IOException; +import java.io.Serializable; + +/** + * Metadata for DiskANN vector index. 
+ *
+ * <p>Stores the file names of the companion files that live alongside the index file:
+ *
+ * <ul>
+ *   <li>{@link #dataFileName()} — raw vector data file
+ *   <li>{@link #pqPivotsFileName()} — PQ codebook (pivots)
+ *   <li>{@link #pqCompressedFileName()} — PQ compressed codes
+ * </ul>
+ */ +public class DiskAnnIndexMeta implements Serializable { + + private static final long serialVersionUID = 1L; + + private static final int VERSION = 1; + + private final int dim; + private final int metricValue; + private final int indexTypeValue; + private final long numVectors; + private final long minId; + private final long maxId; + private final int maxDegree; + private final int buildListSize; + private final int startId; + private final String dataFileName; + private final String pqPivotsFileName; + private final String pqCompressedFileName; + + public DiskAnnIndexMeta( + int dim, + int metricValue, + int indexTypeValue, + long numVectors, + long minId, + long maxId, + int maxDegree, + int buildListSize, + int startId, + String dataFileName, + String pqPivotsFileName, + String pqCompressedFileName) { + this.dim = dim; + this.metricValue = metricValue; + this.indexTypeValue = indexTypeValue; + this.numVectors = numVectors; + this.minId = minId; + this.maxId = maxId; + this.maxDegree = maxDegree; + this.buildListSize = buildListSize; + this.startId = startId; + this.dataFileName = dataFileName; + this.pqPivotsFileName = pqPivotsFileName; + this.pqCompressedFileName = pqCompressedFileName; + } + + public int dim() { + return dim; + } + + public int metricValue() { + return metricValue; + } + + public long numVectors() { + return numVectors; + } + + public long minId() { + return minId; + } + + public long maxId() { + return maxId; + } + + public int maxDegree() { + return maxDegree; + } + + public int buildListSize() { + return buildListSize; + } + + public int startId() { + return startId; + } + + /** The file name of the separate vector data file. */ + public String dataFileName() { + return dataFileName; + } + + /** The file name of the PQ codebook (pivots) file. */ + public String pqPivotsFileName() { + return pqPivotsFileName; + } + + /** The file name of the PQ compressed codes file. */ + public String pqCompressedFileName() { + return pqCompressedFileName; + } + + /** Serialize metadata to byte array. */ + public byte[] serialize() throws IOException { + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + DataOutputStream out = new DataOutputStream(baos); + out.writeInt(VERSION); + out.writeInt(dim); + out.writeInt(metricValue); + out.writeInt(indexTypeValue); + out.writeLong(numVectors); + out.writeLong(minId); + out.writeLong(maxId); + out.writeInt(maxDegree); + out.writeInt(buildListSize); + out.writeInt(startId); + out.writeUTF(dataFileName); + out.writeUTF(pqPivotsFileName); + out.writeUTF(pqCompressedFileName); + out.flush(); + return baos.toByteArray(); + } + + /** Deserialize metadata from byte array. 
*/ + public static DiskAnnIndexMeta deserialize(byte[] data) throws IOException { + DataInputStream in = new DataInputStream(new ByteArrayInputStream(data)); + int version = in.readInt(); + if (version != VERSION) { + throw new IOException("Unsupported DiskANN index meta version: " + version); + } + int dim = in.readInt(); + int metricValue = in.readInt(); + int indexTypeValue = in.readInt(); + long numVectors = in.readLong(); + long minId = in.readLong(); + long maxId = in.readLong(); + int maxDegree = in.readInt(); + int buildListSize = in.readInt(); + int startId = in.readInt(); + String dataFileName = in.readUTF(); + String pqPivotsFileName = in.readUTF(); + String pqCompressedFileName = in.readUTF(); + return new DiskAnnIndexMeta( + dim, + metricValue, + indexTypeValue, + numVectors, + minId, + maxId, + maxDegree, + buildListSize, + startId, + dataFileName, + pqPivotsFileName, + pqCompressedFileName); + } +} diff --git a/paimon-diskann/paimon-diskann-index/src/main/java/org/apache/paimon/diskann/index/DiskAnnScoredGlobalIndexResult.java b/paimon-diskann/paimon-diskann-index/src/main/java/org/apache/paimon/diskann/index/DiskAnnScoredGlobalIndexResult.java new file mode 100644 index 000000000000..c5bf687c02d1 --- /dev/null +++ b/paimon-diskann/paimon-diskann-index/src/main/java/org/apache/paimon/diskann/index/DiskAnnScoredGlobalIndexResult.java @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.paimon.diskann.index; + +import org.apache.paimon.globalindex.ScoreGetter; +import org.apache.paimon.globalindex.ScoredGlobalIndexResult; +import org.apache.paimon.utils.RoaringNavigableMap64; + +import java.util.HashMap; + +/** Vector search global index result for DiskANN vector index. 
*/ +public class DiskAnnScoredGlobalIndexResult implements ScoredGlobalIndexResult { + + private final HashMap id2scores; + private final RoaringNavigableMap64 results; + + public DiskAnnScoredGlobalIndexResult( + RoaringNavigableMap64 results, HashMap id2scores) { + this.id2scores = id2scores; + this.results = results; + } + + @Override + public ScoreGetter scoreGetter() { + return id2scores::get; + } + + @Override + public RoaringNavigableMap64 results() { + return this.results; + } +} diff --git a/paimon-diskann/paimon-diskann-index/src/main/java/org/apache/paimon/diskann/index/DiskAnnVectorGlobalIndexReader.java b/paimon-diskann/paimon-diskann-index/src/main/java/org/apache/paimon/diskann/index/DiskAnnVectorGlobalIndexReader.java new file mode 100644 index 000000000000..44ad2955594a --- /dev/null +++ b/paimon-diskann/paimon-diskann-index/src/main/java/org/apache/paimon/diskann/index/DiskAnnVectorGlobalIndexReader.java @@ -0,0 +1,525 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.paimon.diskann.index; + +import org.apache.paimon.diskann.IndexSearcher; +import org.apache.paimon.fs.Path; +import org.apache.paimon.fs.SeekableInputStream; +import org.apache.paimon.globalindex.GlobalIndexIOMeta; +import org.apache.paimon.globalindex.GlobalIndexReader; +import org.apache.paimon.globalindex.GlobalIndexResult; +import org.apache.paimon.globalindex.io.GlobalIndexFileReader; +import org.apache.paimon.predicate.FieldRef; +import org.apache.paimon.predicate.VectorSearch; +import org.apache.paimon.types.ArrayType; +import org.apache.paimon.types.DataType; +import org.apache.paimon.types.FloatType; +import org.apache.paimon.utils.IOUtils; +import org.apache.paimon.utils.RoaringNavigableMap64; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Comparator; +import java.util.HashMap; +import java.util.List; +import java.util.Optional; +import java.util.PriorityQueue; + +/** + * Vector global index reader using DiskANN. + * + *

This implementation uses DiskANN for efficient approximate nearest neighbor search. Both the + * Vamana graph and full-precision vectors are read on demand from Paimon FileIO-backed storage + * (local, HDFS, S3, OSS, etc.) via {@link SeekableInputStream}, ensuring that neither is loaded + * into Java memory in full. + */ +public class DiskAnnVectorGlobalIndexReader implements GlobalIndexReader { + + /** + * Loaded search handles. Each entry wraps a DiskANN {@link IndexSearcher} (Rust native beam + * search with both graph and vectors read on-demand from FileIO-backed storage via {@link + * FileIOGraphReader} and {@link FileIOVectorReader}). + */ + private final List handles; + + private final List indexMetas; + private final List ioMetas; + private final GlobalIndexFileReader fileReader; + private final DataType fieldType; + private final DiskAnnVectorIndexOptions options; + private volatile boolean metasLoaded = false; + private volatile boolean indicesLoaded = false; + + /** + * Number of vectors to cache per searcher in the LRU cache inside {@link FileIOVectorReader}. + */ + private static final int VECTOR_CACHE_SIZE = 4096; + + public DiskAnnVectorGlobalIndexReader( + GlobalIndexFileReader fileReader, + List ioMetas, + DataType fieldType, + DiskAnnVectorIndexOptions options) { + this.fileReader = fileReader; + this.ioMetas = ioMetas; + this.fieldType = fieldType; + this.options = options; + this.handles = new ArrayList<>(); + this.indexMetas = new ArrayList<>(); + } + + /** Wrapper around a search implementation for lifecycle management. */ + private interface SearchHandle extends AutoCloseable { + void search( + long n, + float[] queryVectors, + int k, + int searchListSize, + float[] distances, + long[] labels); + + @Override + void close() throws IOException; + } + + /** + * Uses DiskANN's native Rust beam search via {@link IndexSearcher}. Both graph and vectors are + * read on demand from Paimon FileIO-backed storage through {@link FileIOGraphReader} and {@link + * FileIOVectorReader} JNI callbacks. 
+ */ + private static class DiskAnnSearchHandle implements SearchHandle { + private final IndexSearcher searcher; + + DiskAnnSearchHandle(IndexSearcher searcher) { + this.searcher = searcher; + } + + @Override + public void search( + long n, + float[] queryVectors, + int k, + int searchListSize, + float[] distances, + long[] labels) { + searcher.search(n, queryVectors, k, searchListSize, distances, labels); + } + + @Override + public void close() { + searcher.close(); + } + } + + @Override + public Optional visitVectorSearch(VectorSearch vectorSearch) { + try { + ensureLoadMetas(); + + RoaringNavigableMap64 includeRowIds = vectorSearch.includeRowIds(); + if (includeRowIds != null) { + List matchingIndices = new ArrayList<>(); + for (int i = 0; i < indexMetas.size(); i++) { + DiskAnnIndexMeta meta = indexMetas.get(i); + if (hasOverlap(meta.minId(), meta.maxId(), includeRowIds)) { + matchingIndices.add(i); + } + } + if (matchingIndices.isEmpty()) { + return Optional.empty(); + } + ensureLoadIndices(matchingIndices); + } else { + ensureLoadAllIndices(); + } + + return Optional.ofNullable(search(vectorSearch)); + } catch (IOException e) { + throw new RuntimeException( + String.format( + "Failed to search DiskANN vector index with fieldName=%s, limit=%d", + vectorSearch.fieldName(), vectorSearch.limit()), + e); + } + } + + private boolean hasOverlap(long minId, long maxId, RoaringNavigableMap64 includeRowIds) { + for (Long id : includeRowIds) { + if (id >= minId && id <= maxId) { + return true; + } + if (id > maxId) { + break; + } + } + return false; + } + + private GlobalIndexResult search(VectorSearch vectorSearch) throws IOException { + validateVectorType(vectorSearch.vector()); + float[] queryVector = ((float[]) vectorSearch.vector()).clone(); + int limit = vectorSearch.limit(); + + PriorityQueue result = + new PriorityQueue<>(Comparator.comparingDouble(sr -> sr.score)); + + RoaringNavigableMap64 includeRowIds = vectorSearch.includeRowIds(); + int searchK = limit; + if (includeRowIds != null) { + searchK = + Math.max( + limit * options.searchFactor(), + (int) includeRowIds.getLongCardinality()); + } + + for (SearchHandle handle : handles) { + if (handle == null) { + continue; + } + int effectiveK = searchK; + if (effectiveK <= 0) { + continue; + } + + float[] distances = new float[effectiveK]; + long[] labels = new long[effectiveK]; + + // Dynamic search list sizing: use max of configured value and effectiveK + // This follows Milvus best practice: search_list should be >= topk + int dynamicSearchListSize = Math.max(options.searchListSize(), effectiveK); + handle.search(1, queryVector, effectiveK, dynamicSearchListSize, distances, labels); + + for (int i = 0; i < effectiveK; i++) { + long rowId = labels[i]; + if (rowId < 0) { + continue; + } + if (includeRowIds != null && !includeRowIds.contains(rowId)) { + continue; + } + float score = convertDistanceToScore(distances[i]); + + if (result.size() < limit) { + result.offer(new ScoredRow(rowId, score)); + } else { + if (result.peek() != null && score > result.peek().score) { + result.poll(); + result.offer(new ScoredRow(rowId, score)); + } + } + } + } + + RoaringNavigableMap64 roaringBitmap64 = new RoaringNavigableMap64(); + HashMap id2scores = new HashMap<>(result.size()); + for (ScoredRow scoredRow : result) { + id2scores.put(scoredRow.rowId, scoredRow.score); + roaringBitmap64.add(scoredRow.rowId); + } + return new DiskAnnScoredGlobalIndexResult(roaringBitmap64, id2scores); + } + + private float convertDistanceToScore(float distance) { + if 
(options.metric() == DiskAnnVectorMetric.L2 + || options.metric() == DiskAnnVectorMetric.COSINE) { + return 1.0f / (1.0f + distance); + } else { + return distance; + } + } + + private void validateVectorType(Object vector) { + if (!(vector instanceof float[])) { + throw new IllegalArgumentException( + "Expected float[] vector but got: " + vector.getClass()); + } + if (!(fieldType instanceof ArrayType) + || !(((ArrayType) fieldType).getElementType() instanceof FloatType)) { + throw new IllegalArgumentException( + "DiskANN currently only supports float arrays, but field type is: " + + fieldType); + } + } + + private void ensureLoadMetas() throws IOException { + if (!metasLoaded) { + synchronized (this) { + if (!metasLoaded) { + for (GlobalIndexIOMeta ioMeta : ioMetas) { + byte[] metaBytes = ioMeta.metadata(); + DiskAnnIndexMeta meta = DiskAnnIndexMeta.deserialize(metaBytes); + indexMetas.add(meta); + } + metasLoaded = true; + } + } + } + } + + private void ensureLoadAllIndices() throws IOException { + if (!indicesLoaded) { + synchronized (this) { + if (!indicesLoaded) { + for (int i = 0; i < ioMetas.size(); i++) { + loadIndexAt(i); + } + indicesLoaded = true; + } + } + } + } + + private void ensureLoadIndices(List positions) throws IOException { + synchronized (this) { + while (handles.size() < ioMetas.size()) { + handles.add(null); + } + for (int pos : positions) { + if (handles.get(pos) == null) { + loadIndexAt(pos); + } + } + } + } + + /** + * Load an index at the given position. + * + *

The index file (graph) and the data file (vectors) are accessed on demand via {@link + * SeekableInputStream}s — neither is loaded into Java memory in full. The PQ pivots and + * compressed codes are loaded into memory as the "memory thumbnail" for approximate distance + * computation during native beam search. + */ + private void loadIndexAt(int position) throws IOException { + GlobalIndexIOMeta ioMeta = ioMetas.get(position); + DiskAnnIndexMeta meta = indexMetas.get(position); + SearchHandle handle = null; + try { + // 1. Open index file (graph only, no header) as a SeekableInputStream. + // FileIOGraphReader scans the graph section + builds offset index; graph neighbors + // are read on demand during beam search. + // numNodes = user vectors + 1 start point. + int numNodes = (int) meta.numVectors() + 1; + SeekableInputStream graphStream = fileReader.getInputStream(ioMeta); + FileIOGraphReader graphReader = + new FileIOGraphReader( + graphStream, + meta.dim(), + meta.metricValue(), + meta.maxDegree(), + meta.buildListSize(), + numNodes, + meta.startId(), + VECTOR_CACHE_SIZE); + + // 2. Open data file stream for on-demand full-vector reads. + Path dataPath = new Path(ioMeta.filePath().getParent(), meta.dataFileName()); + GlobalIndexIOMeta dataIOMeta = new GlobalIndexIOMeta(dataPath, 0L, new byte[0]); + SeekableInputStream vectorStream = fileReader.getInputStream(dataIOMeta); + FileIOVectorReader vectorReader = + new FileIOVectorReader(vectorStream, meta.dim(), meta.maxDegree()); + + // 3. Load PQ files into memory for in-memory approximate distance computation. + // PQ is mandatory — beam search uses PQ brute-force scan followed by + // full-precision reranking from disk. + byte[] pqPivots = loadCompanionFile(ioMeta, meta.pqPivotsFileName()); + byte[] pqCompressed = loadCompanionFile(ioMeta, meta.pqCompressedFileName()); + if (pqPivots == null || pqPivots.length == 0) { + throw new IOException( + "PQ pivots file is missing or empty for index at position " + + position + + ". PQ is required for DiskANN search. " + + "Pivots file: " + + meta.pqPivotsFileName()); + } + if (pqCompressed == null || pqCompressed.length == 0) { + throw new IOException( + "PQ compressed file is missing or empty for index at position " + + position + + ". PQ is required for DiskANN search. " + + "Compressed file: " + + meta.pqCompressedFileName()); + } + + // 4. Create DiskANN native searcher with on-demand graph + vector access + PQ. + handle = + new DiskAnnSearchHandle( + IndexSearcher.createFromReaders( + graphReader, + vectorReader, + meta.dim(), + meta.minId(), + pqPivots, + pqCompressed)); + + if (handles.size() <= position) { + while (handles.size() < position) { + handles.add(null); + } + handles.add(handle); + } else { + handles.set(position, handle); + } + } catch (Exception e) { + IOUtils.closeQuietly(handle); + throw e instanceof IOException ? (IOException) e : new IOException(e); + } + } + + @Override + public void close() throws IOException { + Throwable firstException = null; + + // Close all search handles (also closes their FileIOVectorReader streams). 
+ for (SearchHandle handle : handles) { + if (handle == null) { + continue; + } + try { + handle.close(); + } catch (Throwable t) { + if (firstException == null) { + firstException = t; + } else { + firstException.addSuppressed(t); + } + } + } + handles.clear(); + + if (firstException != null) { + if (firstException instanceof IOException) { + throw (IOException) firstException; + } else if (firstException instanceof RuntimeException) { + throw (RuntimeException) firstException; + } else { + throw new RuntimeException( + "Failed to close DiskANN vector global index reader", firstException); + } + } + } + + /** + * Load a companion file (e.g. PQ pivots/compressed) relative to the index file. + * + * @return the file contents as byte[], or null if the file name is null/empty. + * @throws IOException if the file cannot be read. + */ + private byte[] loadCompanionFile(GlobalIndexIOMeta indexIOMeta, String fileName) + throws IOException { + if (fileName == null || fileName.isEmpty()) { + return null; + } + Path filePath = new Path(indexIOMeta.filePath().getParent(), fileName); + GlobalIndexIOMeta fileMeta = new GlobalIndexIOMeta(filePath, 0L, new byte[0]); + try (SeekableInputStream in = fileReader.getInputStream(fileMeta)) { + // Read the entire file into a byte array. + java.io.ByteArrayOutputStream baos = new java.io.ByteArrayOutputStream(4096); + byte[] buf = new byte[8192]; + int n; + while ((n = in.read(buf)) >= 0) { + baos.write(buf, 0, n); + } + byte[] data = baos.toByteArray(); + return data.length > 0 ? data : null; + } + } + + private static class ScoredRow { + final long rowId; + final float score; + + ScoredRow(long rowId, float score) { + this.rowId = rowId; + this.score = score; + } + } + + // =================== unsupported ===================== + + @Override + public Optional visitIsNotNull(FieldRef fieldRef) { + return Optional.empty(); + } + + @Override + public Optional visitIsNull(FieldRef fieldRef) { + return Optional.empty(); + } + + @Override + public Optional visitStartsWith(FieldRef fieldRef, Object literal) { + return Optional.empty(); + } + + @Override + public Optional visitEndsWith(FieldRef fieldRef, Object literal) { + return Optional.empty(); + } + + @Override + public Optional visitContains(FieldRef fieldRef, Object literal) { + return Optional.empty(); + } + + @Override + public Optional visitLike(FieldRef fieldRef, Object literal) { + return Optional.empty(); + } + + @Override + public Optional visitLessThan(FieldRef fieldRef, Object literal) { + return Optional.empty(); + } + + @Override + public Optional visitGreaterOrEqual(FieldRef fieldRef, Object literal) { + return Optional.empty(); + } + + @Override + public Optional visitNotEqual(FieldRef fieldRef, Object literal) { + return Optional.empty(); + } + + @Override + public Optional visitLessOrEqual(FieldRef fieldRef, Object literal) { + return Optional.empty(); + } + + @Override + public Optional visitEqual(FieldRef fieldRef, Object literal) { + return Optional.empty(); + } + + @Override + public Optional visitGreaterThan(FieldRef fieldRef, Object literal) { + return Optional.empty(); + } + + @Override + public Optional visitIn(FieldRef fieldRef, List literals) { + return Optional.empty(); + } + + @Override + public Optional visitNotIn(FieldRef fieldRef, List literals) { + return Optional.empty(); + } +} diff --git a/paimon-diskann/paimon-diskann-index/src/main/java/org/apache/paimon/diskann/index/DiskAnnVectorGlobalIndexWriter.java 
b/paimon-diskann/paimon-diskann-index/src/main/java/org/apache/paimon/diskann/index/DiskAnnVectorGlobalIndexWriter.java new file mode 100644 index 000000000000..016eed11d867 --- /dev/null +++ b/paimon-diskann/paimon-diskann-index/src/main/java/org/apache/paimon/diskann/index/DiskAnnVectorGlobalIndexWriter.java @@ -0,0 +1,329 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.paimon.diskann.index; + +import org.apache.paimon.data.InternalArray; +import org.apache.paimon.globalindex.GlobalIndexSingletonWriter; +import org.apache.paimon.globalindex.ResultEntry; +import org.apache.paimon.globalindex.io.GlobalIndexFileWriter; +import org.apache.paimon.types.ArrayType; +import org.apache.paimon.types.DataType; +import org.apache.paimon.types.FloatType; + +import java.io.BufferedOutputStream; +import java.io.Closeable; +import java.io.IOException; +import java.io.OutputStream; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.FloatBuffer; +import java.util.ArrayList; +import java.util.List; + +/** + * Vector global index writer using DiskANN. + * + *

The build pipeline follows two phases: + * + *

    + *
  1. Vamana Graph Construction — vectors are added in batches and the Vamana graph (with + * alpha-pruning) is built via the native DiskANN library. + *
  2. PQ Compression — after the graph is built, a Product Quantization codebook is + * trained via the native DiskANN library and all vectors are compressed to compact PQ codes. + *
+ * + *
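In native-call terms, a single flush reduces roughly to the following (a simplified sketch of the write/flush path in this class; buffer handling, serialization and file writing are omitted, and local names are abridged):
 *
 *   index.add(vectorBuffer, n);          // repeated for each pending batch
 *   index.build();                       // phase 1: Vamana graph with alpha-pruning
 *   byte[][] pq = index.pqTrainAndEncode(subspaces, sampleSize, kmeansIterations);
 *                                        // phase 2: pq[0] = pivots, pq[1] = compressed codes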

For each index flush, four files are produced: + * + *

    + *
  • {@code .index} — Vamana graph (adjacency lists only, no header) + *
  • {@code .data} — raw vectors stored sequentially (position = ID) + *
  • {@code .pq_pivots} — PQ codebook (centroids) + *
  • {@code .pq_compressed} — PQ compressed codes (memory thumbnail) + *
+ * + *

The PQ compressed data acts as a "memory thumbnail" — during search it stays resident in + * memory and allows fast approximate distance computation, reducing disk I/O for full vectors. + * + *
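A minimal usage sketch, mirroring the tests added in this change (the {@code fileWriter}, {@code options} and {@code vectors} are assumed to be prepared by the caller):
 *
 *   DiskAnnVectorIndexOptions indexOptions = new DiskAnnVectorIndexOptions(options);
 *   DiskAnnVectorGlobalIndexWriter writer =
 *           new DiskAnnVectorGlobalIndexWriter(
 *                   fileWriter, new ArrayType(new FloatType()), indexOptions);
 *   try {
 *       for (float[] vec : vectors) {
 *           writer.write(vec);
 *       }
 *       List<ResultEntry> entries = writer.finish(); // one entry per flushed index file
 *   } finally {
 *       writer.close(); // releases the native index if finish() failed before flushing
 *   }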

This class implements {@link Closeable} so that the native DiskANN index is released even if + * {@link #finish()} is never called or throws an exception. + */ +public class DiskAnnVectorGlobalIndexWriter implements GlobalIndexSingletonWriter, Closeable { + + private static final int DEFAULT_BATCH_SIZE = 10000; + + private final GlobalIndexFileWriter fileWriter; + private final DiskAnnVectorIndexOptions options; + private final int sizePerIndex; + private final int batchSize; + private final int dim; + private final DataType fieldType; + + private long count = 0; + private long currentIndexCount = 0; + private long currentIndexMinId = Long.MAX_VALUE; + private long currentIndexMaxId = Long.MIN_VALUE; + private final List pendingBatch; + private final List results; + private DiskAnnIndex currentIndex; + private boolean built = false; + + public DiskAnnVectorGlobalIndexWriter( + GlobalIndexFileWriter fileWriter, + DataType fieldType, + DiskAnnVectorIndexOptions options) { + this.fileWriter = fileWriter; + this.fieldType = fieldType; + this.options = options; + this.sizePerIndex = options.sizePerIndex(); + this.batchSize = Math.min(DEFAULT_BATCH_SIZE, sizePerIndex); + this.dim = options.dimension(); + this.pendingBatch = new ArrayList<>(batchSize); + this.results = new ArrayList<>(); + + validateFieldType(fieldType); + } + + private void validateFieldType(DataType dataType) { + if (!(dataType instanceof ArrayType)) { + throw new IllegalArgumentException( + "DiskANN vector index requires ArrayType, but got: " + dataType); + } + DataType elementType = ((ArrayType) dataType).getElementType(); + if (!(elementType instanceof FloatType)) { + throw new IllegalArgumentException( + "DiskANN vector index requires float array, but got: " + elementType); + } + } + + @Override + public void write(Object fieldData) { + float[] vector; + if (fieldData instanceof float[]) { + vector = (float[]) fieldData; + } else if (fieldData instanceof InternalArray) { + vector = ((InternalArray) fieldData).toFloatArray(); + } else { + throw new RuntimeException( + "Unsupported vector type: " + fieldData.getClass().getName()); + } + checkDimension(vector); + currentIndexMinId = Math.min(currentIndexMinId, count); + currentIndexMaxId = Math.max(currentIndexMaxId, count); + pendingBatch.add(new VectorEntry(count, vector)); + count++; + + try { + if (pendingBatch.size() >= batchSize) { + addBatchToIndex(); + } + if (currentIndexCount >= sizePerIndex) { + flushCurrentIndex(); + } + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + @Override + public List finish() { + try { + if (!pendingBatch.isEmpty()) { + addBatchToIndex(); + } + if (currentIndex != null && currentIndexCount > 0) { + flushCurrentIndex(); + } + return results; + } catch (IOException e) { + throw new RuntimeException("Failed to write DiskANN vector global index", e); + } + } + + private void addBatchToIndex() throws IOException { + if (pendingBatch.isEmpty()) { + return; + } + + if (currentIndex == null) { + currentIndex = createIndex(); + built = false; + } + + int n = pendingBatch.size(); + ByteBuffer vectorBuffer = DiskAnnIndex.allocateVectorBuffer(n, dim); + FloatBuffer floatView = vectorBuffer.asFloatBuffer(); + + for (int i = 0; i < n; i++) { + VectorEntry entry = pendingBatch.get(i); + float[] vector = entry.vector; + for (int j = 0; j < dim; j++) { + floatView.put(i * dim + j, vector[j]); + } + } + + currentIndex.add(vectorBuffer, n); + currentIndexCount += n; + pendingBatch.clear(); + built = false; + } + + private void 
flushCurrentIndex() throws IOException { + if (currentIndex == null || currentIndexCount == 0) { + return; + } + + // ---- Phase 1: Vamana graph construction ---- + if (!built) { + currentIndex.build(); + built = true; + } + + // Serialize the graph + vectors into one buffer (no header — metadata goes to + // DiskAnnIndexMeta). + long serializeSize = currentIndex.serializeSize(); + if (serializeSize > Integer.MAX_VALUE) { + throw new IOException( + "Index too large to serialize: " + + serializeSize + + " bytes exceeds maximum buffer size"); + } + + ByteBuffer serializeBuffer = + ByteBuffer.allocateDirect((int) serializeSize).order(ByteOrder.nativeOrder()); + long bytesWritten = currentIndex.serialize(serializeBuffer); + + byte[] fullData = new byte[(int) bytesWritten]; + serializeBuffer.rewind(); + serializeBuffer.get(fullData); + + // Compute split point: data section = user vectors only (no start point). + int dataSectionSize = (int) (currentIndexCount * dim * Float.BYTES); + int graphSectionSize = fullData.length - dataSectionSize; + + // Generate file names — all share the same base name. + String indexFileName = fileWriter.newFileName(DiskAnnVectorGlobalIndexerFactory.IDENTIFIER); + String baseName = indexFileName.replaceAll("\\.index$", ""); + String dataFileName = baseName + ".data"; + String pqPivotsFileName = baseName + ".pq_pivots"; + String pqCompressedFileName = baseName + ".pq_compressed"; + + // Write index file: graph section only (no header). + try (OutputStream out = + new BufferedOutputStream(fileWriter.newOutputStream(indexFileName))) { + out.write(fullData, 0, graphSectionSize); + out.flush(); + } + + // Write data file: raw vectors in sequential order (position = ID). + try (OutputStream out = + new BufferedOutputStream(fileWriter.newOutputStream(dataFileName))) { + out.write(fullData, graphSectionSize, dataSectionSize); + out.flush(); + } + + // ---- Phase 2: PQ Compression & Training (native) ---- + // Train PQ codebook on the vectors stored in the native index and encode all vectors. + byte[][] pqResult = + currentIndex.pqTrainAndEncode( + options.pqSubspaces(), + options.pqSampleSize(), + options.pqKmeansIterations()); + + // Write PQ pivots file (codebook). + try (OutputStream out = + new BufferedOutputStream(fileWriter.newOutputStream(pqPivotsFileName))) { + out.write(pqResult[0]); + out.flush(); + } + + // Write PQ compressed file (memory thumbnail). + try (OutputStream out = + new BufferedOutputStream(fileWriter.newOutputStream(pqCompressedFileName))) { + out.write(pqResult[1]); + out.flush(); + } + + // Build metadata with all companion file names and graph parameters. 
+ DiskAnnIndexMeta meta = + new DiskAnnIndexMeta( + dim, + options.metric().toMetricType().value(), + 0, + currentIndexCount, + currentIndexMinId, + currentIndexMaxId, + options.maxDegree(), + options.buildListSize(), + 0, // startId is always 0 (START_POINT_ID) + dataFileName, + pqPivotsFileName, + pqCompressedFileName); + results.add(new ResultEntry(indexFileName, currentIndexCount, meta.serialize())); + + currentIndex.close(); + currentIndex = null; + currentIndexCount = 0; + currentIndexMinId = Long.MAX_VALUE; + currentIndexMaxId = Long.MIN_VALUE; + built = false; + } + + private DiskAnnIndex createIndex() { + return DiskAnnIndex.create( + options.dimension(), + options.metric(), + options.maxDegree(), + options.buildListSize()); + } + + private void checkDimension(float[] vector) { + if (vector.length != options.dimension()) { + throw new IllegalArgumentException( + String.format( + "Vector dimension mismatch: expected %d, but got %d", + options.dimension(), vector.length)); + } + } + + /** + * Release native resources held by the current in-progress index. + * + *

This is a safety net: under normal operation the index is closed by {@link + * #flushCurrentIndex()}, but if an error occurs before flushing this method ensures the native + * handle is freed. + */ + @Override + public void close() { + if (currentIndex != null) { + currentIndex.close(); + currentIndex = null; + } + } + + /** Entry holding a vector and its row ID. */ + private static class VectorEntry { + final long id; + final float[] vector; + + VectorEntry(long id, float[] vector) { + this.id = id; + this.vector = vector; + } + } +} diff --git a/paimon-diskann/paimon-diskann-index/src/main/java/org/apache/paimon/diskann/index/DiskAnnVectorGlobalIndexer.java b/paimon-diskann/paimon-diskann-index/src/main/java/org/apache/paimon/diskann/index/DiskAnnVectorGlobalIndexer.java new file mode 100644 index 000000000000..4efc00d3a088 --- /dev/null +++ b/paimon-diskann/paimon-diskann-index/src/main/java/org/apache/paimon/diskann/index/DiskAnnVectorGlobalIndexer.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.paimon.diskann.index; + +import org.apache.paimon.globalindex.GlobalIndexIOMeta; +import org.apache.paimon.globalindex.GlobalIndexReader; +import org.apache.paimon.globalindex.GlobalIndexWriter; +import org.apache.paimon.globalindex.GlobalIndexer; +import org.apache.paimon.globalindex.io.GlobalIndexFileReader; +import org.apache.paimon.globalindex.io.GlobalIndexFileWriter; +import org.apache.paimon.options.Options; +import org.apache.paimon.types.DataType; + +import java.util.List; + +/** DiskANN vector global indexer. 
*/ +public class DiskAnnVectorGlobalIndexer implements GlobalIndexer { + + private final DataType fieldType; + private final DiskAnnVectorIndexOptions options; + + public DiskAnnVectorGlobalIndexer(DataType fieldType, Options options) { + this.fieldType = fieldType; + this.options = new DiskAnnVectorIndexOptions(options); + } + + @Override + public GlobalIndexWriter createWriter(GlobalIndexFileWriter fileWriter) { + return new DiskAnnVectorGlobalIndexWriter(fileWriter, fieldType, options); + } + + @Override + public GlobalIndexReader createReader( + GlobalIndexFileReader fileReader, List files) { + return new DiskAnnVectorGlobalIndexReader(fileReader, files, fieldType, options); + } +} diff --git a/paimon-diskann/paimon-diskann-index/src/main/java/org/apache/paimon/diskann/index/DiskAnnVectorGlobalIndexerFactory.java b/paimon-diskann/paimon-diskann-index/src/main/java/org/apache/paimon/diskann/index/DiskAnnVectorGlobalIndexerFactory.java new file mode 100644 index 000000000000..df3d7274744c --- /dev/null +++ b/paimon-diskann/paimon-diskann-index/src/main/java/org/apache/paimon/diskann/index/DiskAnnVectorGlobalIndexerFactory.java @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.paimon.diskann.index; + +import org.apache.paimon.globalindex.GlobalIndexer; +import org.apache.paimon.globalindex.GlobalIndexerFactory; +import org.apache.paimon.options.Options; +import org.apache.paimon.types.DataField; + +/** Factory for creating DiskANN vector index. */ +public class DiskAnnVectorGlobalIndexerFactory implements GlobalIndexerFactory { + + public static final String IDENTIFIER = "diskann-vector-ann"; + + @Override + public String identifier() { + return IDENTIFIER; + } + + @Override + public GlobalIndexer create(DataField field, Options options) { + return new DiskAnnVectorGlobalIndexer(field.type(), options); + } +} diff --git a/paimon-diskann/paimon-diskann-index/src/main/java/org/apache/paimon/diskann/index/DiskAnnVectorIndexOptions.java b/paimon-diskann/paimon-diskann-index/src/main/java/org/apache/paimon/diskann/index/DiskAnnVectorIndexOptions.java new file mode 100644 index 000000000000..373f087cee5a --- /dev/null +++ b/paimon-diskann/paimon-diskann-index/src/main/java/org/apache/paimon/diskann/index/DiskAnnVectorIndexOptions.java @@ -0,0 +1,177 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.paimon.diskann.index; + +import org.apache.paimon.options.ConfigOption; +import org.apache.paimon.options.ConfigOptions; +import org.apache.paimon.options.Options; + +/** Options for DiskANN vector index. */ +public class DiskAnnVectorIndexOptions { + + public static final ConfigOption VECTOR_DIM = + ConfigOptions.key("vector.dim") + .intType() + .defaultValue(128) + .withDescription("The dimension of the vector"); + + public static final ConfigOption VECTOR_METRIC = + ConfigOptions.key("vector.metric") + .enumType(DiskAnnVectorMetric.class) + .defaultValue(DiskAnnVectorMetric.L2) + .withDescription( + "The similarity metric for vector search (L2, INNER_PRODUCT, COSINE), and L2 is the default"); + + public static final ConfigOption VECTOR_MAX_DEGREE = + ConfigOptions.key("vector.diskann.max-degree") + .intType() + .defaultValue(64) + .withDescription("The maximum degree (R) for DiskANN graph construction"); + + public static final ConfigOption VECTOR_BUILD_LIST_SIZE = + ConfigOptions.key("vector.diskann.build-list-size") + .intType() + .defaultValue(100) + .withDescription("The build list size (L) for DiskANN index construction"); + + public static final ConfigOption VECTOR_SEARCH_LIST_SIZE = + ConfigOptions.key("vector.diskann.search-list-size") + .intType() + .defaultValue(100) + .withDescription("The search list size (L) for DiskANN query"); + + public static final ConfigOption VECTOR_SIZE_PER_INDEX = + ConfigOptions.key("vector.size-per-index") + .intType() + .defaultValue(200_0000) + .withDescription("The size of vectors stored in each vector index file"); + + public static final ConfigOption VECTOR_SEARCH_FACTOR = + ConfigOptions.key("vector.search-factor") + .intType() + .defaultValue(10) + .withDescription( + "The multiplier for the search limit when filtering is applied. " + + "This is used to fetch more results to ensure enough records after filtering."); + + public static final ConfigOption VECTOR_PQ_SUBSPACES = + ConfigOptions.key("vector.diskann.pq-subspaces") + .intType() + .defaultValue(-1) + .withDescription( + "Number of subspaces (M) for Product Quantization. " + + "Dimension must be divisible by M. 
" + + "Default (-1) auto-computes as max(1, dim/4)."); + + public static final ConfigOption VECTOR_PQ_KMEANS_ITERATIONS = + ConfigOptions.key("vector.diskann.pq-kmeans-iterations") + .intType() + .defaultValue(20) + .withDescription("Number of K-Means iterations for PQ codebook training."); + + public static final ConfigOption VECTOR_PQ_SAMPLE_SIZE = + ConfigOptions.key("vector.diskann.pq-sample-size") + .intType() + .defaultValue(100_000) + .withDescription("Maximum number of vectors sampled for PQ codebook training."); + + private final int dimension; + private final DiskAnnVectorMetric metric; + private final int maxDegree; + private final int buildListSize; + private final int searchListSize; + private final int sizePerIndex; + private final int searchFactor; + private final int pqSubspaces; + private final int pqKmeansIterations; + private final int pqSampleSize; + + public DiskAnnVectorIndexOptions(Options options) { + this.dimension = options.get(VECTOR_DIM); + this.metric = options.get(VECTOR_METRIC); + this.maxDegree = options.get(VECTOR_MAX_DEGREE); + this.buildListSize = options.get(VECTOR_BUILD_LIST_SIZE); + this.searchListSize = options.get(VECTOR_SEARCH_LIST_SIZE); + this.sizePerIndex = + options.get(VECTOR_SIZE_PER_INDEX) > 0 + ? options.get(VECTOR_SIZE_PER_INDEX) + : VECTOR_SIZE_PER_INDEX.defaultValue(); + this.searchFactor = options.get(VECTOR_SEARCH_FACTOR); + + int rawPqSub = options.get(VECTOR_PQ_SUBSPACES); + this.pqSubspaces = rawPqSub > 0 ? rawPqSub : defaultNumSubspaces(dimension); + this.pqKmeansIterations = options.get(VECTOR_PQ_KMEANS_ITERATIONS); + this.pqSampleSize = options.get(VECTOR_PQ_SAMPLE_SIZE); + } + + public int dimension() { + return dimension; + } + + public DiskAnnVectorMetric metric() { + return metric; + } + + public int maxDegree() { + return maxDegree; + } + + public int buildListSize() { + return buildListSize; + } + + public int searchListSize() { + return searchListSize; + } + + public int sizePerIndex() { + return sizePerIndex; + } + + public int searchFactor() { + return searchFactor; + } + + /** Number of PQ subspaces (M). */ + public int pqSubspaces() { + return pqSubspaces; + } + + /** Number of K-Means iterations for PQ training. */ + public int pqKmeansIterations() { + return pqKmeansIterations; + } + + /** Maximum number of training samples for PQ. */ + public int pqSampleSize() { + return pqSampleSize; + } + + /** + * Compute a reasonable default number of PQ subspaces for the given dimension. The result is + * the largest divisor of {@code dim} that is {@code <= dim / 4} and at least 1. + */ + static int defaultNumSubspaces(int dim) { + int target = Math.max(1, dim / 4); + while (target > 1 && dim % target != 0) { + target--; + } + return target; + } +} diff --git a/paimon-diskann/paimon-diskann-index/src/main/java/org/apache/paimon/diskann/index/DiskAnnVectorMetric.java b/paimon-diskann/paimon-diskann-index/src/main/java/org/apache/paimon/diskann/index/DiskAnnVectorMetric.java new file mode 100644 index 000000000000..6336c3b7d5e9 --- /dev/null +++ b/paimon-diskann/paimon-diskann-index/src/main/java/org/apache/paimon/diskann/index/DiskAnnVectorMetric.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.paimon.diskann.index; + +import org.apache.paimon.diskann.MetricType; + +/** Metric type for DiskANN vector index. */ +public enum DiskAnnVectorMetric { + L2(MetricType.L2), + INNER_PRODUCT(MetricType.INNER_PRODUCT), + COSINE(MetricType.COSINE); + + private final MetricType metricType; + + DiskAnnVectorMetric(MetricType metricType) { + this.metricType = metricType; + } + + MetricType toMetricType() { + return metricType; + } +} diff --git a/paimon-diskann/paimon-diskann-index/src/main/java/org/apache/paimon/diskann/index/FileIOGraphReader.java b/paimon-diskann/paimon-diskann-index/src/main/java/org/apache/paimon/diskann/index/FileIOGraphReader.java new file mode 100644 index 000000000000..4fbe9d964d72 --- /dev/null +++ b/paimon-diskann/paimon-diskann-index/src/main/java/org/apache/paimon/diskann/index/FileIOGraphReader.java @@ -0,0 +1,257 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.paimon.diskann.index; + +import org.apache.paimon.fs.SeekableInputStream; + +import java.io.Closeable; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.util.Arrays; +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.Map; + +/** + * Reads graph structure (neighbors) from a DiskANN index file on demand via a Paimon {@link + * SeekableInputStream}. + * + *

The underlying stream can be backed by any Paimon FileIO provider — local, HDFS, S3, OSS, etc. + * This class adds an LRU cache so that repeated reads for the same node (common during DiskANN's + * beam search) do not trigger redundant I/O. + * + *

The Rust JNI layer invokes {@link #readNeighbors(int)} via reflection during DiskANN's native + * beam search. It also calls getter methods ({@link #getDimension()}, {@link #getCount()}, etc.) + * during searcher initialization. + * + *

Index file layout (graph only, no header)

+ * + *
+ *   Graph section: for each node (count nodes):
+ *     int_id       : i32
+ *     neighbor_cnt : i32
+ *     neighbors    : neighbor_cnt × i32
+ * 
+ * + *
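For illustration, with hypothetical contents where node 0 has 3 neighbors and node 1 has 2, the initial scan walks the records as follows (byte offsets):
 *
 *   node 0: record occupies [0, 20)  -> neighbor_cnt stored at offset 4
 *   node 1: record occupies [20, 36) -> neighbor_cnt stored at offset 24
 *
 * i.e. each record is {@code 8 + neighbor_cnt * 4} bytes, and the offset index keeps the position of the neighbor_cnt field so that {@link #readNeighbors(int)} can later seek straight to it.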

DiskANN stores vectors sequentially — the position IS the ID. Internal IDs map to positions + * via {@code position = int_id - 1} for user vectors. The start point ({@code int_id == startId}) + * is not a user vector. + * + *
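For example, a user vector with {@code int_id = 5} is read from data-file position 4, while the start point ({@code int_id == startId}) has no entry in the data file at all.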

All metadata (dimension, metric, max_degree, etc.) is provided externally via {@link + * DiskAnnIndexMeta} — the file contains only graph data. + * + *
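In this change the reader is wired up from a {@code DiskAnnIndexMeta} roughly as follows (abridged from {@code DiskAnnVectorGlobalIndexReader#loadIndexAt}; {@code cacheSize} stands in for the reader's cache-size constant):
 *
 *   int numNodes = (int) meta.numVectors() + 1; // user vectors + 1 start point
 *   FileIOGraphReader graphReader =
 *           new FileIOGraphReader(
 *                   graphStream,
 *                   meta.dim(),
 *                   meta.metricValue(),
 *                   meta.maxDegree(),
 *                   meta.buildListSize(),
 *                   numNodes,
 *                   meta.startId(),
 *                   cacheSize);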

On construction, the graph section is scanned once sequentially to build an offset index + * (mapping internal node ID → file byte offset for its neighbor data). After that, individual + * neighbor lists are read on demand by seeking to the stored offset. + */ +public class FileIOGraphReader implements Closeable { + + /** Source stream — must support seek(). */ + private final SeekableInputStream input; + + // ---- Metadata fields (from DiskAnnIndexMeta) ---- + private final int dimension; + private final int metricValue; + private final int maxDegree; + private final int buildListSize; + private final int count; + private final int startId; + + // ---- Offset index built during initial scan ---- + + /** Mapping from internal node ID → byte offset of the node's neighbor_cnt field in the file. */ + private final Map nodeNeighborOffsets; + + /** LRU cache: internal node ID → neighbor list (int[]). */ + private final LinkedHashMap cache; + + /** + * Create a reader from metadata and a seekable input stream. + * + *

The stream should point to a file that contains ONLY the graph section (no header, no + * IDs). All metadata is supplied via parameters (originally from {@link DiskAnnIndexMeta}). + * + * @param input seekable input stream for the index file (graph only) + * @param dimension vector dimension + * @param metricValue metric type value (0=L2, 1=IP, 2=Cosine) + * @param maxDegree maximum adjacency list size + * @param buildListSize search list size used during construction + * @param count total number of graph nodes (including start point) + * @param startId internal ID of the graph start/entry point + * @param cacheSize maximum number of cached neighbor lists (0 uses default 4096) + * @throws IOException if reading or parsing fails + */ + public FileIOGraphReader( + SeekableInputStream input, + int dimension, + int metricValue, + int maxDegree, + int buildListSize, + int count, + int startId, + int cacheSize) + throws IOException { + this.input = input; + this.dimension = dimension; + this.metricValue = metricValue; + this.maxDegree = maxDegree; + this.buildListSize = buildListSize; + this.count = count; + this.startId = startId; + + // Scan graph section to build offset index. + // The file starts directly with graph entries (no header). + // Each entry: int_id(4) + neighbor_cnt(4) + neighbors(cnt*4). + this.nodeNeighborOffsets = new HashMap<>(count); + + // Reusable buffer for reading int_id(4) + neighbor_cnt(4) = 8 bytes per node. + byte[] nodeBuf = new byte[8]; + long filePos = 0; + + for (int i = 0; i < count; i++) { + input.seek(filePos); + readFully(input, nodeBuf); + + int intId = readInt(nodeBuf, 0); + int neighborCount = readInt(nodeBuf, 4); + + // Store file offset pointing to the neighbor_cnt field (so readNeighbors can re-read + // count + data). + nodeNeighborOffsets.put(intId, filePos + 4); + + // Advance past: int_id(4) + neighbor_cnt(4) + neighbors(cnt*4). + filePos += 8 + (long) neighborCount * 4; + } + + // Create LRU cache. + final int cap = cacheSize > 0 ? cacheSize : 4096; + this.cache = + new LinkedHashMap(cap, 0.75f, true) { + @Override + protected boolean removeEldestEntry(Map.Entry eldest) { + return size() > cap; + } + }; + } + + // ---- Header accessors (called by Rust JNI during initialization) ---- + + public int getDimension() { + return dimension; + } + + public int getMetricValue() { + return metricValue; + } + + public int getMaxDegree() { + return maxDegree; + } + + public int getBuildListSize() { + return buildListSize; + } + + public int getCount() { + return count; + } + + public int getStartId() { + return startId; + } + + // ---- On-demand neighbor reading (called by Rust JNI during beam search) ---- + + /** + * Read the neighbor list of the given internal node. + * + *

Called by the Rust JNI layer during DiskANN's native beam search. Returns a defensive + * copy — callers may freely modify the returned array without corrupting the cache. + * + * @param internalNodeId the internal (graph) node ID + * @return the neighbor internal IDs, or an empty array if the node is unknown + */ + public int[] readNeighbors(int internalNodeId) { + // 1. LRU cache hit — return a defensive copy. + int[] cached = cache.get(internalNodeId); + if (cached != null) { + return Arrays.copyOf(cached, cached.length); + } + + // 2. Look up file offset. + Long offset = nodeNeighborOffsets.get(internalNodeId); + if (offset == null) { + return new int[0]; // unknown node + } + + // 3. Seek & read neighbor_cnt + neighbor_ids. + try { + input.seek(offset); + + // Read neighbor count (4 bytes). + byte[] cntBuf = new byte[4]; + readFully(input, cntBuf); + int neighborCount = readInt(cntBuf, 0); + + // Read neighbor IDs. + byte[] neighborBuf = new byte[neighborCount * 4]; + readFully(input, neighborBuf); + + int[] neighbors = new int[neighborCount]; + ByteBuffer bb = ByteBuffer.wrap(neighborBuf).order(ByteOrder.nativeOrder()); + bb.asIntBuffer().get(neighbors); + + // 4. Cache a copy. + cache.put(internalNodeId, Arrays.copyOf(neighbors, neighbors.length)); + return neighbors; + } catch (IOException e) { + throw new RuntimeException( + "Failed to read neighbors for node " + internalNodeId + " at offset " + offset, + e); + } + } + + @Override + public void close() throws IOException { + cache.clear(); + input.close(); + } + + // ---- Helpers ---- + + private static void readFully(SeekableInputStream in, byte[] buf) throws IOException { + int off = 0; + while (off < buf.length) { + int n = in.read(buf, off, buf.length - off); + if (n < 0) { + throw new IOException( + "Unexpected end of stream at offset " + off + " of " + buf.length); + } + off += n; + } + } + + private static int readInt(byte[] buf, int off) { + return (buf[off] & 0xFF) + | ((buf[off + 1] & 0xFF) << 8) + | ((buf[off + 2] & 0xFF) << 16) + | ((buf[off + 3] & 0xFF) << 24); + } +} diff --git a/paimon-diskann/paimon-diskann-index/src/main/java/org/apache/paimon/diskann/index/FileIOVectorReader.java b/paimon-diskann/paimon-diskann-index/src/main/java/org/apache/paimon/diskann/index/FileIOVectorReader.java new file mode 100644 index 000000000000..110e47c20cef --- /dev/null +++ b/paimon-diskann/paimon-diskann-index/src/main/java/org/apache/paimon/diskann/index/FileIOVectorReader.java @@ -0,0 +1,241 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.paimon.diskann.index; + +import org.apache.paimon.fs.SeekableInputStream; + +import java.io.Closeable; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; + +/** + * Fetches vectors from a DiskANN data file through a Paimon {@link SeekableInputStream}. + * + *

The underlying stream can be backed by any Paimon FileIO provider — local, HDFS, S3, OSS, etc. + * + *

The Rust JNI layer uses two access modes: + * + *

    + *
  • Single-vector zero-copy: {@link #loadVector(long)} reads a vector into a + * pre-allocated {@link ByteBuffer#allocateDirect DirectByteBuffer}. The Rust side reads + * floats directly from the native memory address — no {@code float[]} allocation and no JNI + * array copy. + *
  • Batch prefetch: {@link #readVectorsBatch(long[], int)} reads multiple vectors into a + * larger DirectByteBuffer in a single JNI call, reducing per-vector JNI round-trip overhead. + *
+ * + *
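As a Java-side illustration of the single-vector mode (in production the caller is the Rust JNI layer, which reads the floats from the buffer's native address instead):
 *
 *   ByteBuffer buf = reader.getDirectBuffer();  // native address is cached once by the Rust side
 *   if (reader.loadVector(position)) {
 *       float first = buf.getFloat(0);          // absolute read; the vector starts at offset 0
 *   }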

Data file layout

+ * + *

Vectors are stored contiguously in sequential order. Each vector occupies {@code dimension * + * 4} bytes (native-order floats). The vector at position {@code i} is at byte offset {@code i * + * dimension * 4}. The sequential position IS the ID. + * + *
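For example, with {@code dimension = 128} the vector with ID 10 occupies bytes {@code [10 * 128 * 4, 11 * 128 * 4)}, i.e. {@code [5120, 5632)}.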

The start point vector is NOT stored in the data file; it is handled in memory by the Rust + * native layer. + */ +public class FileIOVectorReader implements Closeable { + + /** Source stream — must support seek(). */ + private final SeekableInputStream input; + + /** Vector dimension. */ + private final int dimension; + + /** Byte size of a single vector: {@code dimension * Float.BYTES}. */ + private final int vectorBytes; + + /** Reusable heap byte buffer for stream I/O (stream API requires {@code byte[]}). */ + private final byte[] readBuf; + + /** + * Pre-allocated DirectByteBuffer for single-vector reads. Rust reads directly from its native + * address via {@code GetDirectBufferAddress} — zero JNI array copy. + */ + private final ByteBuffer directBuf; + + /** + * Pre-allocated DirectByteBuffer for batch reads. Holds up to {@code maxBatchSize} vectors + * packed sequentially. + */ + private final ByteBuffer batchBuf; + + /** Maximum number of vectors that fit in {@link #batchBuf}. */ + private final int maxBatchSize; + + /** + * Create a reader. + * + * @param input seekable input stream for the data file + * @param dimension vector dimension + * @param maxBatchSize maximum number of vectors in a batch read (typically max_degree) + */ + public FileIOVectorReader(SeekableInputStream input, int dimension, int maxBatchSize) { + this.input = input; + this.dimension = dimension; + this.vectorBytes = dimension * Float.BYTES; + this.readBuf = new byte[vectorBytes]; + this.maxBatchSize = Math.max(maxBatchSize, 1); + + // Single-vector DirectByteBuffer — Rust gets its native address once at init. + this.directBuf = ByteBuffer.allocateDirect(vectorBytes).order(ByteOrder.nativeOrder()); + + // Batch DirectByteBuffer — sized for maxBatchSize vectors. + this.batchBuf = + ByteBuffer.allocateDirect(this.maxBatchSize * vectorBytes) + .order(ByteOrder.nativeOrder()); + } + + // ------------------------------------------------------------------ + // DirectByteBuffer accessors (called by Rust JNI during init) + // ------------------------------------------------------------------ + + /** Return the single-vector DirectByteBuffer. Rust caches its native address. */ + public ByteBuffer getDirectBuffer() { + return directBuf; + } + + /** Return the batch DirectByteBuffer. Rust caches its native address. */ + public ByteBuffer getBatchBuffer() { + return batchBuf; + } + + /** Return the maximum batch size supported by {@link #batchBuf}. */ + public int getMaxBatchSize() { + return maxBatchSize; + } + + // ------------------------------------------------------------------ + // Single-vector zero-copy read (hot path during beam search) + // ------------------------------------------------------------------ + + /** + * Read a single vector into the pre-allocated {@link #directBuf}. + * + *

After this call returns {@code true}, the vector data is available in the DirectByteBuffer + * at offset 0. The Rust side reads floats directly from the native memory address. + * + * @param position 0-based position in the data file (int_id − 1) + * @return {@code true} if the vector was read successfully, {@code false} if position is + * invalid + */ + public boolean loadVector(long position) { + if (position < 0) { + return false; + } + long byteOffset = position * vectorBytes; + try { + input.seek(byteOffset); + readFully(input, readBuf); + } catch (IOException e) { + throw new RuntimeException( + "Failed to read vector at position " + position + " offset " + byteOffset, e); + } + // Copy from heap byte[] into DirectByteBuffer (single memcpy, no float[] allocation). + directBuf.clear(); + directBuf.put(readBuf, 0, vectorBytes); + return true; + } + + // ------------------------------------------------------------------ + // Batch prefetch (reduces JNI call count) + // ------------------------------------------------------------------ + + /** + * Read multiple vectors into the batch DirectByteBuffer in one JNI call. + * + *

Vectors are packed sequentially in the batch buffer: vector i occupies bytes {@code [i * + * vectorBytes, (i+1) * vectorBytes)}. The Rust side reads all vectors from the native address + * after a single JNI round-trip. + * + * @param positions array of 0-based positions (int_id − 1 for each vector) + * @param count number of positions to read (must be ≤ {@link #maxBatchSize}) + * @return number of vectors successfully read (always equals {@code count} on success) + */ + public int readVectorsBatch(long[] positions, int count) { + int n = Math.min(count, maxBatchSize); + batchBuf.clear(); + for (int i = 0; i < n; i++) { + long byteOffset = positions[i] * vectorBytes; + try { + input.seek(byteOffset); + readFully(input, readBuf); + batchBuf.put(readBuf, 0, vectorBytes); + } catch (IOException e) { + throw new RuntimeException( + "Failed to batch-read vector at position " + + positions[i] + + " offset " + + byteOffset, + e); + } + } + return n; + } + + // ------------------------------------------------------------------ + // Legacy read (kept for backward compatibility) + // ------------------------------------------------------------------ + + /** + * Read a vector and return as {@code float[]}. This is the legacy path — prefer {@link + * #loadVector(long)} for the zero-copy hot path. + * + * @param position 0-based position in the data file + * @return the float vector, or {@code null} if position is negative + */ + public float[] readVector(long position) { + if (position < 0) { + return null; + } + long byteOffset = position * vectorBytes; + try { + input.seek(byteOffset); + readFully(input, readBuf); + } catch (IOException e) { + throw new RuntimeException( + "Failed to read vector at position " + position + " offset " + byteOffset, e); + } + float[] vector = new float[dimension]; + ByteBuffer bb = ByteBuffer.wrap(readBuf).order(ByteOrder.nativeOrder()); + bb.asFloatBuffer().get(vector); + return vector; + } + + @Override + public void close() throws IOException { + input.close(); + } + + // ------------------------------------------------------------------ + // Helpers + // ------------------------------------------------------------------ + + private static void readFully(SeekableInputStream in, byte[] buf) throws IOException { + int off = 0; + while (off < buf.length) { + int n = in.read(buf, off, buf.length - off); + if (n < 0) { + throw new IOException( + "Unexpected end of stream at offset " + off + " of " + buf.length); + } + off += n; + } + } +} diff --git a/paimon-diskann/paimon-diskann-index/src/main/resources/META-INF/services/org.apache.paimon.globalindex.GlobalIndexerFactory b/paimon-diskann/paimon-diskann-index/src/main/resources/META-INF/services/org.apache.paimon.globalindex.GlobalIndexerFactory new file mode 100644 index 000000000000..9906dcfa8e46 --- /dev/null +++ b/paimon-diskann/paimon-diskann-index/src/main/resources/META-INF/services/org.apache.paimon.globalindex.GlobalIndexerFactory @@ -0,0 +1,17 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +org.apache.paimon.diskann.index.DiskAnnVectorGlobalIndexerFactory diff --git a/paimon-diskann/paimon-diskann-index/src/test/java/org/apache/paimon/diskann/index/DiskAnnVectorGlobalIndexScanTest.java b/paimon-diskann/paimon-diskann-index/src/test/java/org/apache/paimon/diskann/index/DiskAnnVectorGlobalIndexScanTest.java new file mode 100644 index 000000000000..043115e4bcb5 --- /dev/null +++ b/paimon-diskann/paimon-diskann-index/src/test/java/org/apache/paimon/diskann/index/DiskAnnVectorGlobalIndexScanTest.java @@ -0,0 +1,332 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.paimon.diskann.index; + +import org.apache.paimon.CoreOptions; +import org.apache.paimon.data.BinaryRow; +import org.apache.paimon.data.GenericArray; +import org.apache.paimon.data.GenericRow; +import org.apache.paimon.diskann.DiskAnn; +import org.apache.paimon.diskann.DiskAnnException; +import org.apache.paimon.fs.FileIO; +import org.apache.paimon.fs.Path; +import org.apache.paimon.fs.PositionOutputStream; +import org.apache.paimon.fs.local.LocalFileIO; +import org.apache.paimon.globalindex.ResultEntry; +import org.apache.paimon.globalindex.io.GlobalIndexFileWriter; +import org.apache.paimon.index.GlobalIndexMeta; +import org.apache.paimon.index.IndexFileMeta; +import org.apache.paimon.io.CompactIncrement; +import org.apache.paimon.io.DataIncrement; +import org.apache.paimon.options.Options; +import org.apache.paimon.predicate.VectorSearch; +import org.apache.paimon.schema.Schema; +import org.apache.paimon.schema.SchemaManager; +import org.apache.paimon.schema.TableSchema; +import org.apache.paimon.table.FileStoreTable; +import org.apache.paimon.table.FileStoreTableFactory; +import org.apache.paimon.table.sink.CommitMessage; +import org.apache.paimon.table.sink.CommitMessageImpl; +import org.apache.paimon.table.sink.StreamTableCommit; +import org.apache.paimon.table.sink.StreamTableWrite; +import org.apache.paimon.table.source.ReadBuilder; +import org.apache.paimon.table.source.TableScan; +import org.apache.paimon.types.ArrayType; +import org.apache.paimon.types.DataTypes; +import org.apache.paimon.types.RowType; + +import org.junit.jupiter.api.Assumptions; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; + +import java.io.IOException; +import java.util.ArrayList; 
+import java.util.Collections; +import java.util.List; +import java.util.UUID; + +import static org.assertj.core.api.Assertions.assertThat; + +/** Test for scanning DiskANN vector global index. */ +public class DiskAnnVectorGlobalIndexScanTest { + + @TempDir java.nio.file.Path tempDir; + + private FileStoreTable table; + private String commitUser; + private FileIO fileIO; + private RowType rowType; + private final String vectorFieldName = "vec"; + + @BeforeEach + public void before() throws Exception { + // Skip tests if DiskANN native library is not available + if (!DiskAnn.isLibraryLoaded()) { + try { + DiskAnn.loadLibrary(); + } catch (DiskAnnException e) { + StringBuilder errorMsg = new StringBuilder("DiskANN native library not available."); + errorMsg.append("\nError: ").append(e.getMessage()); + if (e.getCause() != null) { + errorMsg.append("\nCause: ").append(e.getCause().getMessage()); + } + errorMsg.append( + "\n\nTo run DiskANN tests, ensure the paimon-diskann-jni JAR" + + " with native libraries is available in the classpath."); + Assumptions.assumeTrue(false, errorMsg.toString()); + } + } + + Path tablePath = new Path(tempDir.toString()); + fileIO = new LocalFileIO(); + SchemaManager schemaManager = new SchemaManager(fileIO, tablePath); + + Schema schema = + Schema.newBuilder() + .column("id", DataTypes.INT()) + .column(vectorFieldName, new ArrayType(DataTypes.FLOAT())) + .option(CoreOptions.BUCKET.key(), "-1") + .option("vector.dim", "2") + .option("vector.metric", "L2") + .option("data-evolution.enabled", "true") + .option("row-tracking.enabled", "true") + .build(); + + TableSchema tableSchema = schemaManager.createTable(schema); + table = FileStoreTableFactory.create(fileIO, tablePath, tableSchema); + rowType = table.rowType(); + commitUser = UUID.randomUUID().toString(); + } + + @Test + public void testVectorIndexScanEndToEnd() throws Exception { + float[][] vectors = + new float[][] { + new float[] {1.0f, 0.0f}, new float[] {0.95f, 0.1f}, new float[] {0.1f, 0.95f}, + new float[] {0.98f, 0.05f}, new float[] {0.0f, 1.0f}, new float[] {0.05f, 0.98f} + }; + + writeVectors(vectors); + + List indexFiles = buildIndexManually(vectors); + + commitIndex(indexFiles); + + float[] queryVector = new float[] {0.85f, 0.15f}; + VectorSearch vectorSearch = new VectorSearch(queryVector, 2, vectorFieldName); + ReadBuilder readBuilder = table.newReadBuilder().withVectorSearch(vectorSearch); + TableScan scan = readBuilder.newScan(); + List ids = new ArrayList<>(); + readBuilder + .newRead() + .createReader(scan.plan()) + .forEachRemaining( + row -> { + ids.add(row.getInt(0)); + }); + // With L2 distance, the closest vectors to [0.85, 0.15] should be [0.95, 0.1] and [0.98, + // 0.05] + assertThat(ids).containsExactlyInAnyOrder(1, 3); + } + + @Test + public void testVectorIndexScanWithDifferentMetrics() throws Exception { + Path tablePath = new Path(tempDir.toString(), "inner_product"); + fileIO.mkdirs(tablePath); + SchemaManager schemaManager = new SchemaManager(fileIO, tablePath); + + Schema schema = + Schema.newBuilder() + .column("id", DataTypes.INT()) + .column(vectorFieldName, new ArrayType(DataTypes.FLOAT())) + .option(CoreOptions.BUCKET.key(), "-1") + .option("vector.dim", "2") + .option("vector.metric", "INNER_PRODUCT") + .option("data-evolution.enabled", "true") + .option("row-tracking.enabled", "true") + .build(); + + TableSchema tableSchema = schemaManager.createTable(schema); + FileStoreTable ipTable = FileStoreTableFactory.create(fileIO, tablePath, tableSchema); + String 
ipCommitUser = UUID.randomUUID().toString(); + + float[][] vectors = + new float[][] { + new float[] {1.0f, 0.0f}, + new float[] {0.707f, 0.707f}, + new float[] {0.0f, 1.0f}, + }; + + StreamTableWrite write = ipTable.newWrite(ipCommitUser); + for (int i = 0; i < vectors.length; i++) { + write.write(GenericRow.of(i, new GenericArray(vectors[i]))); + } + List messages = write.prepareCommit(false, 0); + StreamTableCommit commit = ipTable.newCommit(ipCommitUser); + commit.commit(0, messages); + write.close(); + + Options options = new Options(ipTable.options()); + DiskAnnVectorIndexOptions indexOptions = new DiskAnnVectorIndexOptions(options); + Path indexDir = ipTable.store().pathFactory().indexPath(); + if (!fileIO.exists(indexDir)) { + fileIO.mkdirs(indexDir); + } + + GlobalIndexFileWriter fileWriter = + new GlobalIndexFileWriter() { + @Override + public String newFileName(String prefix) { + return prefix + "-" + UUID.randomUUID(); + } + + @Override + public PositionOutputStream newOutputStream(String fileName) + throws IOException { + return fileIO.newOutputStream(new Path(indexDir, fileName), false); + } + }; + + DiskAnnVectorGlobalIndexWriter indexWriter = + new DiskAnnVectorGlobalIndexWriter( + fileWriter, new ArrayType(DataTypes.FLOAT()), indexOptions); + for (float[] vec : vectors) { + indexWriter.write(vec); + } + + List entries = indexWriter.finish(); + List metas = new ArrayList<>(); + int fieldId = ipTable.rowType().getFieldIndex(vectorFieldName); + + for (ResultEntry entry : entries) { + long fileSize = fileIO.getFileSize(new Path(indexDir, entry.fileName())); + GlobalIndexMeta globalMeta = + new GlobalIndexMeta(0, vectors.length - 1, fieldId, null, entry.meta()); + + metas.add( + new IndexFileMeta( + DiskAnnVectorGlobalIndexerFactory.IDENTIFIER, + entry.fileName(), + fileSize, + entry.rowCount(), + globalMeta, + (String) null)); + } + + DataIncrement dataIncrement = DataIncrement.indexIncrement(metas); + CommitMessage message = + new CommitMessageImpl( + BinaryRow.EMPTY_ROW, + 0, + 1, + dataIncrement, + CompactIncrement.emptyIncrement()); + ipTable.newCommit(ipCommitUser).commit(1, Collections.singletonList(message)); + + float[] queryVector = new float[] {1.0f, 0.0f}; + VectorSearch vectorSearch = new VectorSearch(queryVector, 1, vectorFieldName); + ReadBuilder readBuilder = ipTable.newReadBuilder().withVectorSearch(vectorSearch); + TableScan scan = readBuilder.newScan(); + List ids = new ArrayList<>(); + readBuilder + .newRead() + .createReader(scan.plan()) + .forEachRemaining( + row -> { + ids.add(row.getInt(0)); + }); + assertThat(ids).containsExactly(0); + } + + private void writeVectors(float[][] vectors) throws Exception { + StreamTableWrite write = table.newWrite(commitUser); + for (int i = 0; i < vectors.length; i++) { + write.write(GenericRow.of(i, new GenericArray(vectors[i]))); + } + List messages = write.prepareCommit(false, 0); + StreamTableCommit commit = table.newCommit(commitUser); + commit.commit(0, messages); + write.close(); + } + + private List buildIndexManually(float[][] vectors) throws Exception { + Options options = new Options(table.options()); + DiskAnnVectorIndexOptions indexOptions = new DiskAnnVectorIndexOptions(options); + Path indexDir = table.store().pathFactory().indexPath(); + if (!fileIO.exists(indexDir)) { + fileIO.mkdirs(indexDir); + } + + GlobalIndexFileWriter fileWriter = + new GlobalIndexFileWriter() { + @Override + public String newFileName(String prefix) { + return prefix + "-" + UUID.randomUUID(); + } + + @Override + public 
PositionOutputStream newOutputStream(String fileName) + throws IOException { + return fileIO.newOutputStream(new Path(indexDir, fileName), false); + } + }; + + DiskAnnVectorGlobalIndexWriter writer = + new DiskAnnVectorGlobalIndexWriter( + fileWriter, new ArrayType(DataTypes.FLOAT()), indexOptions); + for (float[] vec : vectors) { + writer.write(vec); + } + + List entries = writer.finish(); + + List metas = new ArrayList<>(); + int fieldId = rowType.getFieldIndex(vectorFieldName); + + for (ResultEntry entry : entries) { + long fileSize = fileIO.getFileSize(new Path(indexDir, entry.fileName())); + GlobalIndexMeta globalMeta = + new GlobalIndexMeta(0, vectors.length - 1, fieldId, null, entry.meta()); + + metas.add( + new IndexFileMeta( + DiskAnnVectorGlobalIndexerFactory.IDENTIFIER, + entry.fileName(), + fileSize, + entry.rowCount(), + globalMeta, + (String) null)); + } + return metas; + } + + private void commitIndex(List indexFiles) { + StreamTableCommit commit = table.newCommit(commitUser); + DataIncrement dataIncrement = DataIncrement.indexIncrement(indexFiles); + CommitMessage message = + new CommitMessageImpl( + BinaryRow.EMPTY_ROW, + 0, + 1, + dataIncrement, + CompactIncrement.emptyIncrement()); + commit.commit(1, Collections.singletonList(message)); + } +} diff --git a/paimon-diskann/paimon-diskann-index/src/test/java/org/apache/paimon/diskann/index/DiskAnnVectorGlobalIndexTest.java b/paimon-diskann/paimon-diskann-index/src/test/java/org/apache/paimon/diskann/index/DiskAnnVectorGlobalIndexTest.java new file mode 100644 index 000000000000..3a9d66c72a08 --- /dev/null +++ b/paimon-diskann/paimon-diskann-index/src/test/java/org/apache/paimon/diskann/index/DiskAnnVectorGlobalIndexTest.java @@ -0,0 +1,578 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.paimon.diskann.index; + +import org.apache.paimon.diskann.DiskAnn; +import org.apache.paimon.diskann.DiskAnnException; +import org.apache.paimon.fs.FileIO; +import org.apache.paimon.fs.Path; +import org.apache.paimon.fs.PositionOutputStream; +import org.apache.paimon.fs.local.LocalFileIO; +import org.apache.paimon.globalindex.GlobalIndexIOMeta; +import org.apache.paimon.globalindex.GlobalIndexResult; +import org.apache.paimon.globalindex.ResultEntry; +import org.apache.paimon.globalindex.io.GlobalIndexFileReader; +import org.apache.paimon.globalindex.io.GlobalIndexFileWriter; +import org.apache.paimon.options.Options; +import org.apache.paimon.predicate.VectorSearch; +import org.apache.paimon.types.ArrayType; +import org.apache.paimon.types.DataType; +import org.apache.paimon.types.FloatType; +import org.apache.paimon.utils.RoaringNavigableMap64; + +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Assumptions; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Random; +import java.util.UUID; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** Test for {@link DiskAnnVectorGlobalIndexWriter} and {@link DiskAnnVectorGlobalIndexReader}. */ +public class DiskAnnVectorGlobalIndexTest { + + @TempDir java.nio.file.Path tempDir; + + private FileIO fileIO; + private Path indexPath; + private DataType vectorType; + private final String fieldName = "vec"; + + @BeforeEach + public void setup() { + // Skip tests if DiskANN native library is not available + if (!DiskAnn.isLibraryLoaded()) { + try { + DiskAnn.loadLibrary(); + } catch (DiskAnnException e) { + StringBuilder errorMsg = new StringBuilder("DiskANN native library not available."); + errorMsg.append("\nError: ").append(e.getMessage()); + if (e.getCause() != null) { + errorMsg.append("\nCause: ").append(e.getCause().getMessage()); + } + errorMsg.append( + "\n\nTo run DiskANN tests, ensure the paimon-diskann-jni JAR" + + " with native libraries is available in the classpath."); + Assumptions.assumeTrue(false, errorMsg.toString()); + } + } + + fileIO = new LocalFileIO(); + indexPath = new Path(tempDir.toString()); + vectorType = new ArrayType(new FloatType()); + } + + @AfterEach + public void cleanup() throws IOException { + if (fileIO != null) { + fileIO.delete(indexPath, true); + } + } + + private GlobalIndexFileWriter createFileWriter(Path path) { + return new GlobalIndexFileWriter() { + @Override + public String newFileName(String prefix) { + return prefix + "-" + UUID.randomUUID(); + } + + @Override + public PositionOutputStream newOutputStream(String fileName) throws IOException { + return fileIO.newOutputStream(new Path(path, fileName), false); + } + }; + } + + private GlobalIndexFileReader createFileReader(Path basePath) { + return meta -> fileIO.newInputStream(new Path(basePath, meta.filePath())); + } + + @Test + public void testDifferentMetrics() throws IOException { + int dimension = 32; + int numVectors = 20; + + String[] metrics = {"L2", "INNER_PRODUCT", "COSINE"}; + + for (String metric : metrics) { + Options options = createDefaultOptions(dimension); + options.setString("vector.metric", metric); + DiskAnnVectorIndexOptions indexOptions = new DiskAnnVectorIndexOptions(options); + Path metricIndexPath = new 
Path(indexPath, metric.toLowerCase()); + GlobalIndexFileWriter fileWriter = createFileWriter(metricIndexPath); + DiskAnnVectorGlobalIndexWriter writer = + new DiskAnnVectorGlobalIndexWriter(fileWriter, vectorType, indexOptions); + + List testVectors = generateRandomVectors(numVectors, dimension); + testVectors.forEach(writer::write); + + List results = writer.finish(); + assertThat(results).hasSize(1); + + ResultEntry result = results.get(0); + GlobalIndexFileReader fileReader = createFileReader(metricIndexPath); + List metas = new ArrayList<>(); + metas.add( + new GlobalIndexIOMeta( + new Path(metricIndexPath, result.fileName()), + fileIO.getFileSize(new Path(metricIndexPath, result.fileName())), + result.meta())); + + try (DiskAnnVectorGlobalIndexReader reader = + new DiskAnnVectorGlobalIndexReader( + fileReader, metas, vectorType, indexOptions)) { + VectorSearch vectorSearch = new VectorSearch(testVectors.get(0), 3, fieldName); + GlobalIndexResult searchResult = reader.visitVectorSearch(vectorSearch).get(); + assertThat(searchResult).isNotNull(); + } + } + } + + @Test + public void testDefaultOptions() throws IOException { + int dimension = 32; + int numVectors = 100; + + Options options = createDefaultOptions(dimension); + DiskAnnVectorIndexOptions indexOptions = new DiskAnnVectorIndexOptions(options); + GlobalIndexFileWriter fileWriter = createFileWriter(indexPath); + DiskAnnVectorGlobalIndexWriter writer = + new DiskAnnVectorGlobalIndexWriter(fileWriter, vectorType, indexOptions); + + List testVectors = generateRandomVectors(numVectors, dimension); + testVectors.forEach(writer::write); + + List results = writer.finish(); + assertThat(results).hasSize(1); + + ResultEntry result = results.get(0); + GlobalIndexFileReader fileReader = createFileReader(indexPath); + List metas = new ArrayList<>(); + metas.add( + new GlobalIndexIOMeta( + new Path(indexPath, result.fileName()), + fileIO.getFileSize(new Path(indexPath, result.fileName())), + result.meta())); + + try (DiskAnnVectorGlobalIndexReader reader = + new DiskAnnVectorGlobalIndexReader(fileReader, metas, vectorType, indexOptions)) { + VectorSearch vectorSearch = new VectorSearch(testVectors.get(0), 5, fieldName); + GlobalIndexResult searchResult = reader.visitVectorSearch(vectorSearch).get(); + assertThat(searchResult).isNotNull(); + } + } + + @Test + public void testDifferentDimensions() throws IOException { + int[] dimensions = {8, 32, 128, 256}; + + for (int dimension : dimensions) { + Options options = createDefaultOptions(dimension); + DiskAnnVectorIndexOptions indexOptions = new DiskAnnVectorIndexOptions(options); + Path dimIndexPath = new Path(indexPath, "dim_" + dimension); + GlobalIndexFileWriter fileWriter = createFileWriter(dimIndexPath); + DiskAnnVectorGlobalIndexWriter writer = + new DiskAnnVectorGlobalIndexWriter(fileWriter, vectorType, indexOptions); + + int numVectors = 10; + List testVectors = generateRandomVectors(numVectors, dimension); + testVectors.forEach(writer::write); + + List results = writer.finish(); + assertThat(results).hasSize(1); + + ResultEntry result = results.get(0); + GlobalIndexFileReader fileReader = createFileReader(dimIndexPath); + List metas = new ArrayList<>(); + metas.add( + new GlobalIndexIOMeta( + new Path(dimIndexPath, result.fileName()), + fileIO.getFileSize(new Path(dimIndexPath, result.fileName())), + result.meta())); + + try (DiskAnnVectorGlobalIndexReader reader = + new DiskAnnVectorGlobalIndexReader( + fileReader, metas, vectorType, indexOptions)) { + VectorSearch vectorSearch = new 
VectorSearch(testVectors.get(0), 5, fieldName); + GlobalIndexResult searchResult = reader.visitVectorSearch(vectorSearch).get(); + assertThat(searchResult).isNotNull(); + } + } + } + + @Test + public void testDimensionMismatch() throws IOException { + Options options = createDefaultOptions(64); + + GlobalIndexFileWriter fileWriter = createFileWriter(indexPath); + DiskAnnVectorIndexOptions indexOptions = new DiskAnnVectorIndexOptions(options); + DiskAnnVectorGlobalIndexWriter writer = + new DiskAnnVectorGlobalIndexWriter(fileWriter, vectorType, indexOptions); + + float[] wrongDimVector = new float[32]; + assertThatThrownBy(() -> writer.write(wrongDimVector)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("dimension mismatch"); + } + + @Test + public void testFloatVectorIndexEndToEnd() throws IOException { + int dimension = 2; + Options options = createDefaultOptions(dimension); + int sizePerIndex = 3; + options.setInteger("vector.size-per-index", sizePerIndex); + + float[][] vectors = + new float[][] { + new float[] {1.0f, 0.0f}, new float[] {0.95f, 0.1f}, new float[] {0.1f, 0.95f}, + new float[] {0.98f, 0.05f}, new float[] {0.0f, 1.0f}, new float[] {0.05f, 0.98f} + }; + + GlobalIndexFileWriter fileWriter = createFileWriter(indexPath); + DiskAnnVectorIndexOptions indexOptions = new DiskAnnVectorIndexOptions(options); + DiskAnnVectorGlobalIndexWriter writer = + new DiskAnnVectorGlobalIndexWriter(fileWriter, vectorType, indexOptions); + Arrays.stream(vectors).forEach(writer::write); + + List results = writer.finish(); + assertThat(results).hasSize(2); + + GlobalIndexFileReader fileReader = createFileReader(indexPath); + List metas = new ArrayList<>(); + for (ResultEntry result : results) { + metas.add( + new GlobalIndexIOMeta( + new Path(indexPath, result.fileName()), + fileIO.getFileSize(new Path(indexPath, result.fileName())), + result.meta())); + } + + try (DiskAnnVectorGlobalIndexReader reader = + new DiskAnnVectorGlobalIndexReader(fileReader, metas, vectorType, indexOptions)) { + VectorSearch vectorSearch = new VectorSearch(vectors[0], 1, fieldName); + DiskAnnScoredGlobalIndexResult result = + (DiskAnnScoredGlobalIndexResult) reader.visitVectorSearch(vectorSearch).get(); + assertThat(result.results().getLongCardinality()).isEqualTo(1); + long expectedRowId = 0; + assertThat(containsRowId(result, expectedRowId)).isTrue(); + + // Test with filter + expectedRowId = 1; + RoaringNavigableMap64 filterResults = new RoaringNavigableMap64(); + filterResults.add(expectedRowId); + vectorSearch = + new VectorSearch(vectors[0], 1, fieldName).withIncludeRowIds(filterResults); + result = (DiskAnnScoredGlobalIndexResult) reader.visitVectorSearch(vectorSearch).get(); + assertThat(containsRowId(result, expectedRowId)).isTrue(); + + // Test with multiple results + float[] queryVector = new float[] {0.85f, 0.15f}; + vectorSearch = new VectorSearch(queryVector, 2, fieldName); + result = (DiskAnnScoredGlobalIndexResult) reader.visitVectorSearch(vectorSearch).get(); + assertThat(result.results().getLongCardinality()).isEqualTo(2); + } + } + + @Test + public void testInvalidTopK() { + assertThatThrownBy(() -> new VectorSearch(new float[] {0.1f}, 0, fieldName)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Limit must be positive"); + } + + @Test + public void testMultipleIndexFiles() throws IOException { + int dimension = 32; + Options options = createDefaultOptions(dimension); + options.setInteger("vector.size-per-index", 5); + + GlobalIndexFileWriter 
fileWriter = createFileWriter(indexPath); + DiskAnnVectorIndexOptions indexOptions = new DiskAnnVectorIndexOptions(options); + DiskAnnVectorGlobalIndexWriter writer = + new DiskAnnVectorGlobalIndexWriter(fileWriter, vectorType, indexOptions); + + int numVectors = 15; + List testVectors = generateRandomVectors(numVectors, dimension); + testVectors.forEach(writer::write); + + List results = writer.finish(); + assertThat(results).hasSize(3); + + GlobalIndexFileReader fileReader = createFileReader(indexPath); + List metas = new ArrayList<>(); + for (ResultEntry result : results) { + metas.add( + new GlobalIndexIOMeta( + new Path(indexPath, result.fileName()), + fileIO.getFileSize(new Path(indexPath, result.fileName())), + result.meta())); + } + + try (DiskAnnVectorGlobalIndexReader reader = + new DiskAnnVectorGlobalIndexReader(fileReader, metas, vectorType, indexOptions)) { + VectorSearch vectorSearch = new VectorSearch(testVectors.get(10), 3, fieldName); + GlobalIndexResult searchResult = reader.visitVectorSearch(vectorSearch).get(); + assertThat(searchResult).isNotNull(); + assertThat(searchResult.results().getLongCardinality()).isGreaterThan(0); + } + } + + @Test + public void testBatchWriteMultipleFiles() throws IOException { + int dimension = 8; + Options options = createDefaultOptions(dimension); + int sizePerIndex = 100; + options.setInteger("vector.size-per-index", sizePerIndex); + + GlobalIndexFileWriter fileWriter = createFileWriter(indexPath); + DiskAnnVectorIndexOptions indexOptions = new DiskAnnVectorIndexOptions(options); + DiskAnnVectorGlobalIndexWriter writer = + new DiskAnnVectorGlobalIndexWriter(fileWriter, vectorType, indexOptions); + + int numVectors = 350; + List testVectors = generateRandomVectors(numVectors, dimension); + testVectors.forEach(writer::write); + + List results = writer.finish(); + assertThat(results).hasSize(4); + + GlobalIndexFileReader fileReader = createFileReader(indexPath); + List metas = new ArrayList<>(); + for (ResultEntry result : results) { + Path filePath = new Path(indexPath, result.fileName()); + assertThat(fileIO.exists(filePath)).isTrue(); + assertThat(fileIO.getFileSize(filePath)).isGreaterThan(0); + metas.add(new GlobalIndexIOMeta(filePath, fileIO.getFileSize(filePath), result.meta())); + } + + try (DiskAnnVectorGlobalIndexReader reader = + new DiskAnnVectorGlobalIndexReader(fileReader, metas, vectorType, indexOptions)) { + VectorSearch vectorSearch = new VectorSearch(testVectors.get(50), 3, fieldName); + GlobalIndexResult searchResult = reader.visitVectorSearch(vectorSearch).get(); + assertThat(searchResult).isNotNull(); + assertThat(searchResult.results().getLongCardinality()).isGreaterThan(0); + + vectorSearch = new VectorSearch(testVectors.get(150), 3, fieldName); + searchResult = reader.visitVectorSearch(vectorSearch).get(); + assertThat(searchResult).isNotNull(); + assertThat(searchResult.results().getLongCardinality()).isGreaterThan(0); + + vectorSearch = new VectorSearch(testVectors.get(320), 3, fieldName); + searchResult = reader.visitVectorSearch(vectorSearch).get(); + assertThat(searchResult).isNotNull(); + assertThat(searchResult.results().getLongCardinality()).isGreaterThan(0); + + vectorSearch = new VectorSearch(testVectors.get(200), 1, fieldName); + DiskAnnScoredGlobalIndexResult result = + (DiskAnnScoredGlobalIndexResult) reader.visitVectorSearch(vectorSearch).get(); + assertThat(containsRowId(result, 200)).isTrue(); + } + } + + @Test + public void testBatchWriteWithRemainder() throws IOException { + int dimension = 16; + 
Options options = createDefaultOptions(dimension); + int sizePerIndex = 50; + options.setInteger("vector.size-per-index", sizePerIndex); + + GlobalIndexFileWriter fileWriter = createFileWriter(indexPath); + DiskAnnVectorIndexOptions indexOptions = new DiskAnnVectorIndexOptions(options); + DiskAnnVectorGlobalIndexWriter writer = + new DiskAnnVectorGlobalIndexWriter(fileWriter, vectorType, indexOptions); + + int numVectors = 73; + List testVectors = generateRandomVectors(numVectors, dimension); + testVectors.forEach(writer::write); + + List results = writer.finish(); + assertThat(results).hasSize(2); + + GlobalIndexFileReader fileReader = createFileReader(indexPath); + List metas = new ArrayList<>(); + for (ResultEntry result : results) { + metas.add( + new GlobalIndexIOMeta( + new Path(indexPath, result.fileName()), + fileIO.getFileSize(new Path(indexPath, result.fileName())), + result.meta())); + } + + try (DiskAnnVectorGlobalIndexReader reader = + new DiskAnnVectorGlobalIndexReader(fileReader, metas, vectorType, indexOptions)) { + VectorSearch vectorSearch = new VectorSearch(testVectors.get(60), 1, fieldName); + DiskAnnScoredGlobalIndexResult result = + (DiskAnnScoredGlobalIndexResult) reader.visitVectorSearch(vectorSearch).get(); + assertThat(result).isNotNull(); + assertThat(containsRowId(result, 60)).isTrue(); + + vectorSearch = new VectorSearch(testVectors.get(72), 1, fieldName); + result = (DiskAnnScoredGlobalIndexResult) reader.visitVectorSearch(vectorSearch).get(); + assertThat(result).isNotNull(); + assertThat(containsRowId(result, 72)).isTrue(); + } + } + + @Test + public void testPqFilesProducedWithCorrectStructure() throws IOException { + int dimension = 8; + int numVectors = 50; + + Options options = createDefaultOptions(dimension); + DiskAnnVectorIndexOptions indexOptions = new DiskAnnVectorIndexOptions(options); + GlobalIndexFileWriter fileWriter = createFileWriter(indexPath); + DiskAnnVectorGlobalIndexWriter writer = + new DiskAnnVectorGlobalIndexWriter(fileWriter, vectorType, indexOptions); + + List testVectors = generateRandomVectors(numVectors, dimension); + testVectors.forEach(writer::write); + + List results = writer.finish(); + assertThat(results).hasSize(1); + + ResultEntry result = results.get(0); + DiskAnnIndexMeta meta = DiskAnnIndexMeta.deserialize(result.meta()); + + // Verify all four files exist. 
+ Path indexFilePath = new Path(indexPath, result.fileName()); + Path dataFilePath = new Path(indexPath, meta.dataFileName()); + Path pqPivotsPath = new Path(indexPath, meta.pqPivotsFileName()); + Path pqCompressedPath = new Path(indexPath, meta.pqCompressedFileName()); + + assertThat(fileIO.exists(indexFilePath)).as("Index file should exist").isTrue(); + assertThat(fileIO.exists(dataFilePath)).as("Data file should exist").isTrue(); + assertThat(fileIO.exists(pqPivotsPath)).as("PQ pivots file should exist").isTrue(); + assertThat(fileIO.exists(pqCompressedPath)).as("PQ compressed file should exist").isTrue(); + + // Verify PQ pivots header: dim(i32), M(i32), K(i32), subDim(i32) + byte[] pivotsData = readAllBytesFromFile(pqPivotsPath); + assertThat(pivotsData.length).as("PQ pivots should have data").isGreaterThan(16); + + java.nio.ByteBuffer pivotsBuf = + java.nio.ByteBuffer.wrap(pivotsData).order(java.nio.ByteOrder.nativeOrder()); + int readDim = pivotsBuf.getInt(); + int readM = pivotsBuf.getInt(); + int readK = pivotsBuf.getInt(); + int readSubDim = pivotsBuf.getInt(); + + assertThat(readDim).as("PQ pivots dimension").isEqualTo(dimension); + int expectedM = indexOptions.pqSubspaces(); + assertThat(readM).as("PQ pivots num_subspaces").isEqualTo(expectedM); + assertThat(readK).as("PQ pivots num_centroids").isGreaterThan(0).isLessThanOrEqualTo(256); + assertThat(readSubDim).as("PQ pivots sub_dimension").isEqualTo(dimension / expectedM); + + // Verify total pivots file size: 16 header + M * K * subDim * 4 + int expectedPivotsSize = 16 + readM * readK * readSubDim * 4; + assertThat(pivotsData.length).as("PQ pivots file size").isEqualTo(expectedPivotsSize); + + // Verify PQ compressed header: N(i32), M(i32), then N*M bytes of codes + byte[] compressedData = readAllBytesFromFile(pqCompressedPath); + assertThat(compressedData.length).as("PQ compressed should have data").isGreaterThan(8); + + java.nio.ByteBuffer compBuf = + java.nio.ByteBuffer.wrap(compressedData).order(java.nio.ByteOrder.nativeOrder()); + int readN = compBuf.getInt(); + int readCompM = compBuf.getInt(); + + assertThat(readN).as("PQ compressed num_vectors").isEqualTo(numVectors); + assertThat(readCompM).as("PQ compressed num_subspaces").isEqualTo(expectedM); + + // Verify total compressed file size: 8 header + N * M + int expectedCompSize = 8 + readN * readCompM; + assertThat(compressedData.length).as("PQ compressed file size").isEqualTo(expectedCompSize); + + // Verify search still works with these PQ files. 
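// Note (hedged, not verified by this test): the reader below receives only the .index file's
// GlobalIndexIOMeta, so the .data and PQ sidecar files are presumably resolved through the file
// names recorded in DiskAnnIndexMeta (dataFileName / pqPivotsFileName / pqCompressedFileName).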
+ GlobalIndexFileReader fileReader = createFileReader(indexPath); + List metas = new ArrayList<>(); + metas.add( + new GlobalIndexIOMeta( + indexFilePath, fileIO.getFileSize(indexFilePath), result.meta())); + + try (DiskAnnVectorGlobalIndexReader reader = + new DiskAnnVectorGlobalIndexReader(fileReader, metas, vectorType, indexOptions)) { + VectorSearch vectorSearch = new VectorSearch(testVectors.get(0), 3, fieldName); + GlobalIndexResult searchResult = reader.visitVectorSearch(vectorSearch).get(); + assertThat(searchResult).isNotNull(); + assertThat(searchResult.results().getLongCardinality()).isGreaterThan(0); + } + } + + private Options createDefaultOptions(int dimension) { + Options options = new Options(); + options.setInteger("vector.dim", dimension); + options.setString("vector.metric", "L2"); + options.setInteger("vector.diskann.max-degree", 64); + options.setInteger("vector.diskann.build-list-size", 100); + options.setInteger("vector.diskann.search-list-size", 100); + return options; + } + + private List generateRandomVectors(int count, int dimension) { + Random random = new Random(42); + List vectors = new ArrayList<>(); + for (int i = 0; i < count; i++) { + float[] vector = new float[dimension]; + for (int j = 0; j < dimension; j++) { + vector[j] = random.nextFloat() * 2 - 1; + } + float norm = 0; + for (float v : vector) { + norm += v * v; + } + norm = (float) Math.sqrt(norm); + if (norm > 0) { + for (int m = 0; m < vector.length; m++) { + vector[m] /= norm; + } + } + vectors.add(vector); + } + return vectors; + } + + private byte[] readAllBytesFromFile(Path path) throws IOException { + int fileSize = (int) fileIO.getFileSize(path); + byte[] data = new byte[fileSize]; + try (java.io.InputStream in = fileIO.newInputStream(path)) { + int offset = 0; + while (offset < fileSize) { + int read = in.read(data, offset, fileSize - offset); + if (read < 0) { + break; + } + offset += read; + } + } + return data; + } + + private boolean containsRowId(GlobalIndexResult result, long rowId) { + List resultIds = new ArrayList<>(); + result.results().iterator().forEachRemaining(resultIds::add); + return resultIds.contains(rowId); + } +} diff --git a/paimon-diskann/paimon-diskann-jni/pom.xml b/paimon-diskann/paimon-diskann-jni/pom.xml new file mode 100644 index 000000000000..630f31e3b67e --- /dev/null +++ b/paimon-diskann/paimon-diskann-jni/pom.xml @@ -0,0 +1,104 @@ + + + + 4.0.0 + + + paimon-diskann + org.apache.paimon + 1.4-SNAPSHOT + + + paimon-diskann-jni + Paimon : DiskANN JNI + + + + org.slf4j + slf4j-api + + + + + org.junit.jupiter + junit-jupiter + test + + + + + + + + org.apache.maven.plugins + maven-resources-plugin + + + so + + so.* + dylib + + + + + + + + + + release + + + + + org.apache.maven.plugins + maven-resources-plugin + + + copy-native-libs + prepare-package + + copy-resources + + + ${project.build.outputDirectory} + + + ${project.basedir}/src/main/resources + + **/*.so + + **/*.so.* + **/*.dylib + + + + + + + + + + + + diff --git a/paimon-diskann/paimon-diskann-jni/scripts/build-native.sh b/paimon-diskann/paimon-diskann-jni/scripts/build-native.sh new file mode 100755 index 000000000000..61763a8129a7 --- /dev/null +++ b/paimon-diskann/paimon-diskann-jni/scripts/build-native.sh @@ -0,0 +1,427 @@ +#!/bin/bash +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +set -e + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +PROJECT_DIR="$(dirname "$SCRIPT_DIR")" +NATIVE_DIR="$PROJECT_DIR/src/main/native" +BUILD_DIR="$PROJECT_DIR/build/native" + +# Parse arguments +CLEAN=false +RELEASE=true + +while [[ $# -gt 0 ]]; do + case $1 in + --clean) + CLEAN=true + shift + ;; + --debug) + RELEASE=false + shift + ;; + --help) + echo "Usage: $0 [options]" + echo "" + echo "Options:" + echo " --clean Clean build directory before building" + echo " --debug Build in debug mode (default: release)" + echo " --help Show this help message" + echo "" + echo "Environment variables:" + echo " RUST_TARGET Cargo target triple (e.g. aarch64-apple-darwin)" + echo " RUSTFLAGS Extra rustc flags" + echo " CARGO_FEATURES Extra cargo features (comma-separated)" + echo "" + echo "Example:" + echo " $0 --clean" + exit 0 + ;; + *) + echo "Unknown option: $1" + exit 1 + ;; + esac +done + +echo "================================================" +echo "Building Paimon DiskANN JNI - Native Library" +echo "================================================" +echo "Build mode: $([ "$RELEASE" = true ] && echo release || echo debug)" +echo "" + +if [ "$CLEAN" = true ]; then + echo "Cleaning build directory..." + rm -rf "$BUILD_DIR" +fi + +mkdir -p "$BUILD_DIR" +cd "$NATIVE_DIR" + +if [ ! -f "$NATIVE_DIR/Cargo.toml" ]; then + echo "ERROR: Cargo.toml not found in $NATIVE_DIR" + echo "Place the Rust JNI crate under src/main/native." + exit 1 +fi + +# Ensure $HOME/.cargo/bin is in PATH (rustup installs here). +if [ -d "$HOME/.cargo/bin" ]; then + export PATH="$HOME/.cargo/bin:$PATH" +fi + +# ---------- Check for required tools: rustup, rustc, cargo ---------- +# The rust-toolchain.toml in the native crate directory specifies Rust 1.90.0 +# which is required by diskann-vector v0.45.0 (uses unsigned_is_multiple_of, +# stabilised in Rust 1.87). +REQUIRED_RUST_VERSION="1.90.0" + +# 1. Ensure rustup is available; install if missing. +if ! command -v rustup &> /dev/null; then + echo "" + echo "rustup not found. Installing rustup with Rust $REQUIRED_RUST_VERSION..." + curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y --default-toolchain "$REQUIRED_RUST_VERSION" + export PATH="$HOME/.cargo/bin:$PATH" +fi + +# 2. Ensure the required toolchain (and its cargo/rustc) is installed. +echo "" +echo "Ensuring the required Rust toolchain ($REQUIRED_RUST_VERSION) is installed..." +if ! rustup run "$REQUIRED_RUST_VERSION" rustc --version &> /dev/null; then + echo "Toolchain $REQUIRED_RUST_VERSION not found. Installing..." + rustup toolchain install "$REQUIRED_RUST_VERSION" --profile minimal +fi + +# 3. Verify cargo is usable with the required toolchain. +if ! rustup run "$REQUIRED_RUST_VERSION" cargo --version &> /dev/null; then + echo "cargo not usable with toolchain $REQUIRED_RUST_VERSION. Re-installing..." 
+ rustup toolchain uninstall "$REQUIRED_RUST_VERSION" 2>/dev/null || true + rustup toolchain install "$REQUIRED_RUST_VERSION" --profile minimal +fi + +echo " rustc: $(rustup run "$REQUIRED_RUST_VERSION" rustc --version)" +echo " cargo: $(rustup run "$REQUIRED_RUST_VERSION" cargo --version)" + +# Detect platform +OS=$(uname -s) +ARCH=$(uname -m) + +echo "Detected platform: $OS $ARCH" + +# Build with Cargo +echo "" +echo "Building with Cargo..." + +CARGO_ARGS=() +if [ "$RELEASE" = true ]; then + CARGO_ARGS+=(--release) +fi +if [ -n "$RUST_TARGET" ]; then + CARGO_ARGS+=(--target "$RUST_TARGET") + echo "Using RUST_TARGET: $RUST_TARGET" +fi +if [ -n "$CARGO_FEATURES" ]; then + CARGO_ARGS+=(--features "$CARGO_FEATURES") + echo "Using CARGO_FEATURES: $CARGO_FEATURES" +fi + +cargo build "${CARGO_ARGS[@]}" + +echo "" +echo "============================================" +echo "Build completed successfully!" +echo "============================================" + +# Determine output directory based on platform +if [ "$OS" = "Linux" ]; then + PLATFORM_OS="linux" + if [ "$ARCH" = "x86_64" ] || [ "$ARCH" = "amd64" ]; then + PLATFORM_ARCH="amd64" + else + PLATFORM_ARCH="aarch64" + fi +elif [ "$OS" = "Darwin" ]; then + PLATFORM_OS="darwin" + if [ "$ARCH" = "arm64" ]; then + PLATFORM_ARCH="aarch64" + else + PLATFORM_ARCH="amd64" + fi +else + echo "Unsupported OS: $OS" + exit 1 +fi + +OUTPUT_DIR="$PROJECT_DIR/src/main/resources/$PLATFORM_OS/$PLATFORM_ARCH" +mkdir -p "$OUTPUT_DIR" + +LIB_NAME="libpaimon_diskann_jni" +BUILD_MODE_DIR="$([ "$RELEASE" = true ] && echo release || echo debug)" + +if [ -n "$RUST_TARGET" ]; then + TARGET_DIR="target/$RUST_TARGET/$BUILD_MODE_DIR" +else + TARGET_DIR="target/$BUILD_MODE_DIR" +fi + +if [ "$OS" = "Darwin" ]; then + SRC_LIB="$NATIVE_DIR/$TARGET_DIR/$LIB_NAME.dylib" +else + SRC_LIB="$NATIVE_DIR/$TARGET_DIR/$LIB_NAME.so" +fi + +if [ ! -f "$SRC_LIB" ]; then + echo "ERROR: Built library not found: $SRC_LIB" + exit 1 +fi + +cp "$SRC_LIB" "$OUTPUT_DIR/" + +# ===================================================================== +# Bundle shared library dependencies +# ===================================================================== +# Rust cdylib statically links all Rust code but may dynamically link +# to system C/C++ libraries (libgcc_s, libstdc++, etc.). On Linux CI +# containers the target machine may have different versions, so we +# bundle all non-trivial dependencies — mirroring the FAISS approach. +# ===================================================================== + +echo "" +echo "============================================" +echo "Checking & bundling library dependencies" +echo "============================================" + +if [ "$OS" = "Linux" ]; then + # ---- Helper: copy a real library file into OUTPUT_DIR ---- + bundle_lib() { + local src_path="$1" + local target_name="$2" + + if [ -f "$OUTPUT_DIR/$target_name" ]; then + echo " Already bundled: $target_name" + return 0 + fi + + # Resolve symlinks to the real file + local real_path + real_path=$(readlink -f "$src_path" 2>/dev/null || realpath "$src_path" 2>/dev/null || echo "$src_path") + if [ ! 
-f "$real_path" ]; then + echo " Cannot resolve: $src_path" + return 1 + fi + + cp "$real_path" "$OUTPUT_DIR/$target_name" + chmod +x "$OUTPUT_DIR/$target_name" + echo " Bundled: $real_path -> $target_name" + return 0 + } + + # ---- Helper: search common paths for a library by glob pattern ---- + find_and_bundle() { + local pattern="$1" + local target_name="$2" + + if [ -f "$OUTPUT_DIR/$target_name" ]; then + echo " Already bundled: $target_name" + return 0 + fi + + for search_path in /usr/local/lib /usr/local/lib64 \ + /usr/lib /usr/lib64 \ + /usr/lib/x86_64-linux-gnu /usr/lib/aarch64-linux-gnu; do + local found_lib + found_lib=$(find "$search_path" -maxdepth 1 -name "$pattern" -type f 2>/dev/null | head -1) + if [ -n "$found_lib" ] && [ -f "$found_lib" ]; then + bundle_lib "$found_lib" "$target_name" + return $? + fi + local found_link + found_link=$(find "$search_path" -maxdepth 1 -name "$pattern" -type l 2>/dev/null | head -1) + if [ -n "$found_link" ] && [ -L "$found_link" ]; then + bundle_lib "$found_link" "$target_name" + return $? + fi + done + + # Try ldconfig cache + local ldconfig_path + ldconfig_path=$(ldconfig -p 2>/dev/null | grep "$pattern" | head -1 | awk '{print $NF}') + if [ -n "$ldconfig_path" ] && [ -f "$ldconfig_path" ]; then + bundle_lib "$ldconfig_path" "$target_name" + return $? + fi + + echo " Not found: $pattern" + return 1 + } + + echo "" + echo "Bundling required libraries..." + + # 1. GCC runtime (Rust cdylib may link against libgcc_s for stack unwinding) + if ! find_and_bundle "libgcc_s.so*" "libgcc_s.so.1"; then + echo " Note: libgcc_s not found as shared library - likely statically linked" + fi + + # 2. C++ standard library (needed if the diskann crate compiles any C++ code) + if ! find_and_bundle "libstdc++.so*" "libstdc++.so.6"; then + echo " Note: libstdc++ not found as shared library - likely statically linked" + fi + + # ---- Scan ldd for additional non-system dependencies ---- + echo "" + echo "Scanning ldd for additional dependencies..." + JNI_LIB="$OUTPUT_DIR/$(basename "$SRC_LIB")" + LIBS_TO_CHECK="$JNI_LIB" + for bundled_lib in "$OUTPUT_DIR"/*.so*; do + [ -f "$bundled_lib" ] && LIBS_TO_CHECK="$LIBS_TO_CHECK $bundled_lib" + done + + LIBS_CHECKED="" + while [ -n "$LIBS_TO_CHECK" ]; do + CURRENT_LIB=$(echo "$LIBS_TO_CHECK" | awk '{print $1}') + LIBS_TO_CHECK=$(echo "$LIBS_TO_CHECK" | cut -d' ' -f2-) + [ "$LIBS_TO_CHECK" = "$CURRENT_LIB" ] && LIBS_TO_CHECK="" + + # Skip already-checked + echo "$LIBS_CHECKED" | grep -q "$CURRENT_LIB" 2>/dev/null && continue + LIBS_CHECKED="$LIBS_CHECKED $CURRENT_LIB" + + [ ! 
-f "$CURRENT_LIB" ] && continue + + echo " Checking deps of: $(basename "$CURRENT_LIB")" + + DEPS=$(ldd "$CURRENT_LIB" 2>/dev/null | grep "=>" | awk '{print $1 " " $3}') || true + + while IFS= read -r dep_line; do + [ -z "$dep_line" ] && continue + DEP_NAME=$(echo "$dep_line" | awk '{print $1}') + DEP_PATH=$(echo "$dep_line" | awk '{print $2}') + + # Skip universally-available system libraries + case "$DEP_NAME" in + linux-vdso.so*|libc.so*|libm.so*|libpthread.so*|libdl.so*|librt.so*|ld-linux*) + continue + ;; + esac + + # Bundle known problematic libraries + case "$DEP_NAME" in + libgcc_s*) + bundle_lib "$DEP_PATH" "libgcc_s.so.1" || true + ;; + libstdc++*) + if bundle_lib "$DEP_PATH" "libstdc++.so.6"; then + LIBS_TO_CHECK="$LIBS_TO_CHECK $OUTPUT_DIR/libstdc++.so.6" + fi + ;; + libgomp*) + if bundle_lib "$DEP_PATH" "libgomp.so.1"; then + LIBS_TO_CHECK="$LIBS_TO_CHECK $OUTPUT_DIR/libgomp.so.1" + fi + ;; + libquadmath*) + bundle_lib "$DEP_PATH" "libquadmath.so.0" || true + ;; + libgfortran*) + bundle_lib "$DEP_PATH" "libgfortran.so.3" || true + ;; + esac + done <<< "$DEPS" + done + + # ---- Set rpath to $ORIGIN so bundled libs are found at load time ---- + if command -v patchelf &>/dev/null; then + echo "" + echo "Setting rpath to \$ORIGIN for all libraries..." + for lib in "$OUTPUT_DIR"/*.so*; do + if [ -f "$lib" ]; then + patchelf --set-rpath '$ORIGIN' "$lib" 2>/dev/null || true + fi + done + echo "Done setting rpath" + else + echo "" + echo "WARNING: patchelf not found, cannot set rpath." + echo " Install with: sudo apt-get install patchelf" + echo " The Java loader will still pre-load deps from JAR, but setting" + echo " rpath provides an additional safety net." + fi + +elif [ "$OS" = "Darwin" ]; then + # On macOS, Rust cdylibs are normally self-contained. + # But check if any non-system dylibs are referenced. + echo "" + echo "Checking macOS dylib dependencies..." + DYLIB_PATH="$OUTPUT_DIR/$(basename "$SRC_LIB")" + otool -L "$DYLIB_PATH" 2>/dev/null | tail -n +2 | while read -r dep_entry; do + dep_path=$(echo "$dep_entry" | awk '{print $1}') + case "$dep_path" in + /usr/lib/*|/System/*|@rpath/*|@loader_path/*|@executable_path/*) + # System or relative — OK + ;; + *) + if [ -f "$dep_path" ]; then + dep_basename=$(basename "$dep_path") + if [ ! 
-f "$OUTPUT_DIR/$dep_basename" ]; then + echo " Bundling macOS dep: $dep_path -> $dep_basename" + cp "$dep_path" "$OUTPUT_DIR/$dep_basename" + chmod +x "$OUTPUT_DIR/$dep_basename" + # Rewrite the install name so the JNI lib finds the bundled copy + install_name_tool -change "$dep_path" "@loader_path/$dep_basename" "$DYLIB_PATH" 2>/dev/null || true + fi + fi + ;; + esac + done +fi + +# ===================================================================== +# Summary: list all libraries and their dependencies +# ===================================================================== + +echo "" +echo "============================================" +echo "Native library summary" +echo "============================================" + +BUILT_LIBS=$(find "$PROJECT_DIR/src/main/resources" -type f \( -name "*.so" -o -name "*.so.*" -o -name "*.dylib" \) 2>/dev/null) + +if [ -n "$BUILT_LIBS" ]; then + for lib in $BUILT_LIBS; do + echo "" + echo "Library: $lib" + ls -la "$lib" + + echo "" + echo "Dependencies:" + if [ "$OS" = "Darwin" ]; then + otool -L "$lib" 2>/dev/null | head -20 || true + elif [ "$OS" = "Linux" ]; then + ldd "$lib" 2>/dev/null | head -20 || readelf -d "$lib" 2>/dev/null | grep NEEDED | head -20 || true + fi + done +else + echo " (no libraries found)" + ls -la "$PROJECT_DIR/src/main/resources/"*/*/ 2>/dev/null || true +fi + +echo "" +echo "To package the JAR with native libraries, run:" +echo " mvn package" diff --git a/paimon-diskann/paimon-diskann-jni/src/main/java/org/apache/paimon/diskann/DiskAnn.java b/paimon-diskann/paimon-diskann-jni/src/main/java/org/apache/paimon/diskann/DiskAnn.java new file mode 100644 index 000000000000..dd252b82671c --- /dev/null +++ b/paimon-diskann/paimon-diskann-jni/src/main/java/org/apache/paimon/diskann/DiskAnn.java @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.paimon.diskann; + +/** Global DiskANN configuration and utilities. */ +public final class DiskAnn { + + static { + try { + NativeLibraryLoader.load(); + } catch (DiskAnnException e) { + // Library loading failed silently during class init. + // Callers should check isLibraryLoaded() or call loadLibrary() explicitly. + } + } + + private DiskAnn() {} + + /** + * Ensure the native library is loaded. + * + *

This method is called automatically when any DiskANN class is used. It can be called + * explicitly to load the library early and catch any loading errors. + * + * @throws DiskAnnException if the native library cannot be loaded + */ + public static void loadLibrary() throws DiskAnnException { + NativeLibraryLoader.load(); + } + + /** + * Check if the native library has been loaded. + * + * @return true if the library is loaded + */ + public static boolean isLibraryLoaded() { + return NativeLibraryLoader.isLoaded(); + } +} diff --git a/paimon-diskann/paimon-diskann-jni/src/main/java/org/apache/paimon/diskann/DiskAnnException.java b/paimon-diskann/paimon-diskann-jni/src/main/java/org/apache/paimon/diskann/DiskAnnException.java new file mode 100644 index 000000000000..21cbec85c41c --- /dev/null +++ b/paimon-diskann/paimon-diskann-jni/src/main/java/org/apache/paimon/diskann/DiskAnnException.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.paimon.diskann; + +/** Exception for DiskANN JNI failures. */ +public class DiskAnnException extends RuntimeException { + + public DiskAnnException(String message) { + super(message); + } + + public DiskAnnException(String message, Throwable cause) { + super(message, cause); + } +} diff --git a/paimon-diskann/paimon-diskann-jni/src/main/java/org/apache/paimon/diskann/DiskAnnNative.java b/paimon-diskann/paimon-diskann-jni/src/main/java/org/apache/paimon/diskann/DiskAnnNative.java new file mode 100644 index 000000000000..035f6e6de5b5 --- /dev/null +++ b/paimon-diskann/paimon-diskann-jni/src/main/java/org/apache/paimon/diskann/DiskAnnNative.java @@ -0,0 +1,158 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.paimon.diskann; + +import java.nio.ByteBuffer; + +/** + * Native method declarations for DiskANN JNI with zero-copy support. + * + *
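// A minimal sketch (not part of the patch): eager native-library loading built only on the
// DiskAnn / DiskAnnException API declared above. The helper class and the place it would be
// called from are hypothetical.
final class DiskAnnBootstrap {

    static void ensureNativeLibrary() {
        if (!DiskAnn.isLibraryLoaded()) {
            try {
                DiskAnn.loadLibrary();
            } catch (DiskAnnException e) {
                // Fail fast with a clear message instead of a later UnsatisfiedLinkError.
                throw new IllegalStateException("DiskANN native library could not be loaded", e);
            }
        }
    }
}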

Users should not call these methods directly. Instead, use the high-level Java API classes + * like {@link Index}. + */ +final class DiskAnnNative { + + static { + try { + NativeLibraryLoader.load(); + } catch (DiskAnnException e) { + // Library loading failed silently during class init. + // Native methods will throw UnsatisfiedLinkError if called without the library. + } + } + + /** Create a DiskANN index with the given parameters. */ + static native long indexCreate( + int dimension, int metricType, int indexType, int maxDegree, int buildListSize); + + /** Destroy an index and free its resources. */ + static native void indexDestroy(long handle); + + /** Get the number of vectors in an index. */ + static native long indexGetCount(long handle); + + /** Get the metric type of an index. */ + static native int indexGetMetricType(long handle); + + /** Add vectors to an index using a direct ByteBuffer (zero-copy). */ + static native void indexAdd(long handle, long n, ByteBuffer vectorBuffer); + + /** Build the index graph after adding vectors. */ + static native void indexBuild(long handle, int buildListSize); + + /** + * Search for the k nearest neighbors. + * + * @param handle the native handle of the index + * @param n the number of query vectors + * @param queryVectors the query vectors (n * dimension floats) + * @param k the number of nearest neighbors to find + * @param searchListSize the size of the search list (DiskANN L parameter) + * @param distances output array for distances (n * k floats) + * @param labels output array for labels (n * k longs) + */ + static native void indexSearch( + long handle, + long n, + float[] queryVectors, + int k, + int searchListSize, + float[] distances, + long[] labels); + + /** + * Serialize an index with its graph adjacency lists to a direct ByteBuffer. + * + *

The format stores the Vamana graph structure alongside vector data, so the graph can be + * loaded for search without re-building from scratch. + * + * @param handle the native handle of the in-memory index. + * @param buffer a direct ByteBuffer at least {@link #indexSerializeSize} bytes. + * @return the number of bytes written. + */ + static native long indexSerialize(long handle, ByteBuffer buffer); + + /** Return the number of bytes needed for serialization. */ + static native long indexSerializeSize(long handle); + + /** + * Create a search-only index from two on-demand readers: one for the graph structure + * and one for vectors. + * + *

Neither the graph data nor the vector data is loaded into Java memory upfront. Instead: + * + *

    + *
  • Graph: the Rust side calls {@code graphReader.readNeighbors(int)} via JNI to + * fetch neighbor lists on demand during beam search. It also calls getter methods ({@code + * getDimension()}, {@code getCount()}, {@code getStartId()}) during initialization. + *
  • Vectors: the Rust side calls {@code vectorReader.loadVector(long)} via JNI + * (zero-copy via DirectByteBuffer). + *
  • PQ: the PQ codebook and compressed codes are loaded into native memory for + * in-memory approximate distance computation during beam search. Only the final top-K + * candidates are re-ranked with full-precision vectors from disk. + *
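// A sketch (not part of the patch) of the callback contract listed above. The method names
// (readNeighbors, loadVector, getDimension, getCount, getStartId) come from this javadoc; the
// class names and return types are assumptions — in-memory stand-ins, not the FileIOGraphReader /
// FileIOVectorReader referenced elsewhere in this change.
final class InMemoryGraphReader implements java.io.Closeable {

    private final int[][] neighbors; // adjacency lists indexed by internal node id
    private final int dimension;
    private final int startId;

    InMemoryGraphReader(int[][] neighbors, int dimension, int startId) {
        this.neighbors = neighbors;
        this.dimension = dimension;
        this.startId = startId;
    }

    // Invoked from Rust during beam search to fetch one node's neighbor list on demand.
    public int[] readNeighbors(int nodeId) {
        return neighbors[nodeId];
    }

    public int getDimension() {
        return dimension;
    }

    public int getCount() {
        return neighbors.length;
    }

    public int getStartId() {
        return startId;
    }

    @Override
    public void close() {}
}

final class InMemoryVectorReader implements java.io.Closeable {

    private final float[][] vectors;

    InMemoryVectorReader(float[][] vectors) {
        this.vectors = vectors;
    }

    // Invoked from Rust to fetch one full-precision vector; a direct, native-order buffer keeps
    // the hand-off zero-copy.
    public java.nio.ByteBuffer loadVector(long rowId) {
        float[] v = vectors[(int) rowId];
        java.nio.ByteBuffer buf =
                java.nio.ByteBuffer.allocateDirect(v.length * Float.BYTES)
                        .order(java.nio.ByteOrder.nativeOrder());
        buf.asFloatBuffer().put(v);
        return buf;
    }

    @Override
    public void close() {}
}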
+ * + * @param graphReader a Java object providing graph structure on demand. + * @param vectorReader a Java object with a {@code loadVector(long)} method. + * @param minExtId minimum external ID for this index (for int_id → ext_id conversion). + * @param pqPivots serialized PQ codebook (pivots), or null/empty to disable PQ. + * @param pqCompressed serialized PQ compressed codes, or null/empty to disable PQ. + * @return a searcher handle (>= 100 000) for use with {@link #indexSearchWithReader}. + */ + static native long indexCreateSearcherFromReaders( + Object graphReader, + Object vectorReader, + long minExtId, + byte[] pqPivots, + byte[] pqCompressed); + + /** + * Search on a searcher handle created by {@link #indexCreateSearcherFromReaders}. + * + * @see #indexSearch for parameter descriptions — semantics are identical. + */ + static native void indexSearchWithReader( + long handle, + long n, + float[] queryVectors, + int k, + int searchListSize, + float[] distances, + long[] labels); + + /** Destroy a searcher handle and free its resources. */ + static native void indexDestroySearcher(long handle); + + /** + * Train a Product Quantization codebook on the vectors stored in the index and encode all + * vectors into compact PQ codes. + * + *

The index must have had vectors added via {@link #indexAddWithIds} before calling this + * method. + * + * @param handle the native index handle. + * @param numSubspaces number of PQ subspaces (M). Dimension must be divisible by M. + * @param maxSamples maximum number of vectors sampled for K-Means training. + * @param kmeansIters number of K-Means iterations. + * @return {@code byte[2]} where {@code [0]} is the serialized PQ pivots (codebook) and {@code + * [1]} is the serialized compressed PQ codes. + */ + static native byte[][] pqTrainAndEncode( + long handle, int numSubspaces, int maxSamples, int kmeansIters); +} diff --git a/paimon-diskann/paimon-diskann-jni/src/main/java/org/apache/paimon/diskann/Index.java b/paimon-diskann/paimon-diskann-jni/src/main/java/org/apache/paimon/diskann/Index.java new file mode 100644 index 000000000000..5291aa045fd7 --- /dev/null +++ b/paimon-diskann/paimon-diskann-jni/src/main/java/org/apache/paimon/diskann/Index.java @@ -0,0 +1,176 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.paimon.diskann; + +import java.nio.ByteBuffer; +import java.nio.ByteOrder; + +/** + * A DiskANN index for similarity search with zero-copy support. + * + *
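// A minimal build-and-search sketch (not part of the patch) against the Index API declared below.
// MetricType.L2, the indexType value 0, and the degree / list-size numbers are illustrative
// assumptions rather than values taken from this change.
final class IndexUsageSketch {

    static long[] nearest(float[][] data, float[] query, int k) {
        int dimension = query.length;
        try (Index index = Index.create(dimension, MetricType.L2, 0, 64, 100)) {
            java.nio.ByteBuffer buffer = Index.allocateVectorBuffer(data.length, dimension);
            java.nio.FloatBuffer floats = buffer.asFloatBuffer();
            for (float[] v : data) {
                floats.put(v); // pack vectors row-major into the direct buffer
            }
            index.add(data.length, buffer); // zero-copy hand-off to the native side
            index.build(100); // build the Vamana graph
            float[] distances = new float[k]; // n * k distances, n = 1 here
            long[] labels = new long[k]; // n * k labels, n = 1 here
            index.search(1, query, k, 100, distances, labels);
            return labels;
        }
    }
}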

Thread Safety: Index instances are NOT thread-safe. External synchronization is required if an + * index is accessed from multiple threads. + */ +public class Index implements AutoCloseable { + + /** Native handle to the DiskANN index. */ + private long nativeHandle; + + /** The dimension of vectors in this index. */ + private final int dimension; + + /** Whether this index has been closed. */ + private volatile boolean closed = false; + + Index(long nativeHandle, int dimension) { + this.nativeHandle = nativeHandle; + this.dimension = dimension; + } + + public int getDimension() { + return dimension; + } + + public long getCount() { + checkNotClosed(); + return DiskAnnNative.indexGetCount(nativeHandle); + } + + public MetricType getMetricType() { + checkNotClosed(); + return MetricType.fromValue(DiskAnnNative.indexGetMetricType(nativeHandle)); + } + + public void add(long n, ByteBuffer vectorBuffer) { + checkNotClosed(); + validateDirectBuffer(vectorBuffer, n * dimension * Float.BYTES, "vector"); + DiskAnnNative.indexAdd(nativeHandle, n, vectorBuffer); + } + + public void build(int buildListSize) { + checkNotClosed(); + DiskAnnNative.indexBuild(nativeHandle, buildListSize); + } + + public void search( + long n, + float[] queryVectors, + int k, + int searchListSize, + float[] distances, + long[] labels) { + checkNotClosed(); + if (queryVectors.length < n * dimension) { + throw new IllegalArgumentException( + "Query vectors array too small: required " + + (n * dimension) + + ", got " + + queryVectors.length); + } + if (distances.length < n * k) { + throw new IllegalArgumentException( + "Distances array too small: required " + (n * k) + ", got " + distances.length); + } + if (labels.length < n * k) { + throw new IllegalArgumentException( + "Labels array too small: required " + (n * k) + ", got " + labels.length); + } + DiskAnnNative.indexSearch( + nativeHandle, n, queryVectors, k, searchListSize, distances, labels); + } + + /** Return the number of bytes needed for serialization. */ + public long serializeSize() { + checkNotClosed(); + return DiskAnnNative.indexSerializeSize(nativeHandle); + } + + /** + * Serialize the Vamana graph adjacency lists and vectors into the given direct ByteBuffer. + * + *
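// A sketch (not part of the patch) of the expected serialize flow: size the buffer with
// serializeSize(), allocate a direct buffer (native order is an assumption here), then write the
// graph and vectors into it.
final class IndexSerializeSketch {

    static java.nio.ByteBuffer serializeToBuffer(Index index) {
        long size = index.serializeSize();
        // The (int) cast assumes the serialized index fits in a single < 2 GB buffer.
        java.nio.ByteBuffer out =
                java.nio.ByteBuffer.allocateDirect((int) size)
                        .order(java.nio.ByteOrder.nativeOrder());
        long written = index.serialize(out);
        // 'written' is expected to equal 'size'; dimension/metric metadata is kept separately
        // in DiskAnnIndexMeta rather than in these bytes.
        return out;
    }
}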

The serialized data contains the graph section followed by the vector section, with no + * header. Metadata (dimension, metric, etc.) is stored separately in {@code DiskAnnIndexMeta}. + * + * @return the number of bytes written + */ + public long serialize(ByteBuffer buffer) { + checkNotClosed(); + if (!buffer.isDirect()) { + throw new IllegalArgumentException("Buffer must be a direct buffer"); + } + return DiskAnnNative.indexSerialize(nativeHandle, buffer); + } + + public static Index create( + int dimension, MetricType metricType, int indexType, int maxDegree, int buildListSize) { + long handle = + DiskAnnNative.indexCreate( + dimension, metricType.value(), indexType, maxDegree, buildListSize); + return new Index(handle, dimension); + } + + /** + * Train a PQ codebook on the vectors in this index and encode all vectors. + * + * @param numSubspaces number of PQ subspaces (M). + * @param maxSamples maximum training samples for K-Means. + * @param kmeansIters number of K-Means iterations. + * @return {@code byte[2]}: [0] = serialized pivots, [1] = serialized compressed codes. + */ + public byte[][] pqTrainAndEncode(int numSubspaces, int maxSamples, int kmeansIters) { + checkNotClosed(); + return DiskAnnNative.pqTrainAndEncode(nativeHandle, numSubspaces, maxSamples, kmeansIters); + } + + public static ByteBuffer allocateVectorBuffer(int numVectors, int dimension) { + return ByteBuffer.allocateDirect(numVectors * dimension * Float.BYTES) + .order(ByteOrder.nativeOrder()); + } + + private void validateDirectBuffer(ByteBuffer buffer, long requiredBytes, String name) { + if (!buffer.isDirect()) { + throw new IllegalArgumentException(name + " buffer must be a direct buffer"); + } + if (buffer.capacity() < requiredBytes) { + throw new IllegalArgumentException( + name + + " buffer too small: required " + + requiredBytes + + " bytes, got " + + buffer.capacity()); + } + } + + private void checkNotClosed() { + if (closed) { + throw new IllegalStateException("Index has been closed"); + } + } + + @Override + public synchronized void close() { + if (!closed) { + closed = true; + if (nativeHandle != 0) { + DiskAnnNative.indexDestroy(nativeHandle); + nativeHandle = 0; + } + } + } +} diff --git a/paimon-diskann/paimon-diskann-jni/src/main/java/org/apache/paimon/diskann/IndexSearcher.java b/paimon-diskann/paimon-diskann-jni/src/main/java/org/apache/paimon/diskann/IndexSearcher.java new file mode 100644 index 000000000000..1c15436e59c8 --- /dev/null +++ b/paimon-diskann/paimon-diskann-jni/src/main/java/org/apache/paimon/diskann/IndexSearcher.java @@ -0,0 +1,166 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.paimon.diskann; + +import java.io.Closeable; + +/** + * A search-only DiskANN index backed by Paimon FileIO (local, HDFS, S3, OSS, etc.). + * + *

Both the Vamana graph and full-precision vectors are read on-demand from FileIO-backed storage + * during beam search. The Rust JNI code invokes Java reader callbacks: + * + *

    + *
  • {@code graphReader.readNeighbors(int)} — fetches graph neighbor lists from the {@code + * .index} file via {@code SeekableInputStream}. + *
  • {@code vectorReader.readVector(long)} — fetches full-precision vectors from the {@code + * .data} file via {@code SeekableInputStream}. + *
+ * + *

Frequently accessed data is cached (graph neighbors in a {@code DashMap}, vectors in an LRU + * cache) to reduce FileIO/JNI round-trips. + * + *

Thread Safety: instances are not thread-safe. + */ +public class IndexSearcher implements AutoCloseable { + + /** Native searcher handle (≥100 000, distinct from Index handles). */ + private long nativeHandle; + + private final Closeable graphReader; + + private final Closeable vectorReader; + + /** Vector dimension. */ + private final int dimension; + + private volatile boolean closed = false; + + private IndexSearcher( + long nativeHandle, int dimension, Closeable graphReader, Closeable vectorReader) { + this.nativeHandle = nativeHandle; + this.dimension = dimension; + this.graphReader = graphReader; + this.vectorReader = vectorReader; + } + + /** + * Create a search-only index from two on-demand readers. + * + *

Neither the graph nor the vector data is loaded into Java memory upfront. The Rust JNI + * code invokes: + * + *

+ * <ul> + *   <li>{@code graphReader.readNeighbors(int)} for neighbor lists during beam search + *   <li>{@code vectorReader.loadVector(long)} for full-precision vectors (zero-copy via + * DirectByteBuffer) + * </ul> + *

When PQ data is provided, beam search uses PQ-reconstructed vectors for approximate + * distance computation (fully in-memory), and only the final top-K candidates are re-ranked + * with full-precision vectors from disk I/O. + * + *

Both readers must implement {@link Closeable}. + * + * @param graphReader a graph reader object (e.g. {@code FileIOGraphReader}). + * @param vectorReader a vector reader object (e.g. {@code FileIOVectorReader}). + * @param dimension the vector dimension (from {@code DiskAnnIndexMeta}). + * @param minExtId minimum external ID for this index (for int_id → ext_id conversion). + * @param pqPivots serialized PQ codebook, or null to disable PQ-accelerated search. + * @param pqCompressed serialized PQ compressed codes, or null to disable. + * @return a new IndexSearcher + */ + public static IndexSearcher createFromReaders( + Closeable graphReader, + Closeable vectorReader, + int dimension, + long minExtId, + byte[] pqPivots, + byte[] pqCompressed) { + long handle = + DiskAnnNative.indexCreateSearcherFromReaders( + graphReader, vectorReader, minExtId, pqPivots, pqCompressed); + return new IndexSearcher(handle, dimension, graphReader, vectorReader); + } + + /** Return the vector dimension of this index. */ + public int getDimension() { + return dimension; + } + + /** + * Search for the k nearest neighbors. + * + *

Semantics are identical to {@link Index#search}. During beam search the Rust code will + * invoke the vector reader's {@code readVector} to fetch vectors it has not yet cached. + */ + public void search( + long n, + float[] queryVectors, + int k, + int searchListSize, + float[] distances, + long[] labels) { + checkNotClosed(); + if (queryVectors.length < n * dimension) { + throw new IllegalArgumentException( + "Query vectors array too small: required " + + (n * dimension) + + ", got " + + queryVectors.length); + } + if (distances.length < n * k) { + throw new IllegalArgumentException( + "Distances array too small: required " + (n * k) + ", got " + distances.length); + } + if (labels.length < n * k) { + throw new IllegalArgumentException( + "Labels array too small: required " + (n * k) + ", got " + labels.length); + } + DiskAnnNative.indexSearchWithReader( + nativeHandle, n, queryVectors, k, searchListSize, distances, labels); + } + + private void checkNotClosed() { + if (closed) { + throw new IllegalStateException("IndexSearcher has been closed"); + } + } + + @Override + public synchronized void close() { + if (!closed) { + closed = true; + if (nativeHandle != 0) { + DiskAnnNative.indexDestroySearcher(nativeHandle); + nativeHandle = 0; + } + try { + if (graphReader != null) { + graphReader.close(); + } + if (vectorReader != null) { + vectorReader.close(); + } + } catch (Exception e) { + // best-effort + } + } + } +} diff --git a/paimon-diskann/paimon-diskann-jni/src/main/java/org/apache/paimon/diskann/MetricType.java b/paimon-diskann/paimon-diskann-jni/src/main/java/org/apache/paimon/diskann/MetricType.java new file mode 100644 index 000000000000..d995866a91d6 --- /dev/null +++ b/paimon-diskann/paimon-diskann-jni/src/main/java/org/apache/paimon/diskann/MetricType.java @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.paimon.diskann; + +/** DiskANN metric type mappings for JNI. 
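For reference, a usage sketch of the IndexSearcher API defined above. The calls shown (createFromReaders, search, close via try-with-resources) all appear in this file; construction of the FileIO-backed readers (e.g. FileIOGraphReader / FileIOVectorReader) is outside this hunk, so they are passed in as opaque Closeables, and the PQ arrays can be the two byte arrays produced by Index.pqTrainAndEncode, or null to disable PQ-accelerated search.

import java.io.Closeable;
import org.apache.paimon.diskann.IndexSearcher;

class DiskAnnSearchSketch {
    static long[] topK(Closeable graphReader, Closeable vectorReader, int dimension, long minExtId,
                       byte[] pqPivots, byte[] pqCompressed, float[] query, int k) throws Exception {
        try (IndexSearcher searcher = IndexSearcher.createFromReaders(
                graphReader, vectorReader, dimension, minExtId, pqPivots, pqCompressed)) {
            float[] distances = new float[k];
            long[] labels = new long[k];
            // One query (n = 1); a searchListSize of 100 is an arbitrary example value.
            searcher.search(1, query, k, 100, distances, labels);
            return labels; // external ids; distances[i] holds the distance for labels[i]
        }
    }
}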
*/ +public enum MetricType { + L2(0), + INNER_PRODUCT(1), + COSINE(2); + + private final int value; + + MetricType(int value) { + this.value = value; + } + + public int value() { + return value; + } + + public static MetricType fromValue(int value) { + for (MetricType type : values()) { + if (type.value == value) { + return type; + } + } + throw new IllegalArgumentException("Unknown metric type value: " + value); + } +} diff --git a/paimon-diskann/paimon-diskann-jni/src/main/java/org/apache/paimon/diskann/NativeLibraryLoader.java b/paimon-diskann/paimon-diskann-jni/src/main/java/org/apache/paimon/diskann/NativeLibraryLoader.java new file mode 100644 index 000000000000..2429e4b8fd57 --- /dev/null +++ b/paimon-diskann/paimon-diskann-jni/src/main/java/org/apache/paimon/diskann/NativeLibraryLoader.java @@ -0,0 +1,266 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.paimon.diskann; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.File; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.nio.file.Files; +import java.nio.file.Path; + +/** + * Native library loader for DiskANN JNI. + * + *

The loader attempts to load the library in the following order: + * + *

+ * <ol> + *   <li>From the path specified by the {@code paimon.diskann.lib.path} system property + *   <li>From the system library path using {@code System.loadLibrary} + *   <li>From the JAR file bundled with the distribution + * </ol>
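For the first step in this order, a small sketch of pointing the loader at a locally built library during development; the path is illustrative, and the same property can be set on the JVM command line with -Dpaimon.diskann.lib.path.

import org.apache.paimon.diskann.DiskAnnException;
import org.apache.paimon.diskann.NativeLibraryLoader;

class NativeLibOverrideSketch {
    static void useLocalNativeBuild() throws DiskAnnException {
        // Illustrative absolute path to a locally built library.
        System.setProperty("paimon.diskann.lib.path", "/tmp/diskann/libpaimon_diskann_jni.so");
        NativeLibraryLoader.load(); // otherwise falls back to System.loadLibrary, then the copy bundled in the JAR
    }
}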
+ */ +public class NativeLibraryLoader { + private static final Logger LOG = LoggerFactory.getLogger(NativeLibraryLoader.class); + + /** The name of the native library. */ + private static final String JNI_LIBRARY_NAME = "paimon_diskann_jni"; + + /** System property to specify a custom path to the native library. */ + private static final String LIBRARY_PATH_PROPERTY = "paimon.diskann.lib.path"; + + /** + * Dependency libraries that need to be loaded before the main JNI library. These are bundled in + * the JAR when the build script detects they are dynamically linked. + * + *

Order matters! Libraries must be loaded before the libraries that depend on them. The Rust + * {@code cdylib} statically links all Rust code but may dynamically link against the GCC + * runtime and C++ standard library on Linux. + */ + private static final String[] DEPENDENCY_LIBRARIES = { + // GCC runtime (Rust cdylib uses libgcc_s for stack unwinding on Linux) + "libgcc_s.so.1", + // C++ standard library (needed if diskann crate compiles C++ code internally) + "libstdc++.so.6", + // OpenMP runtime (possible transitive dependency) + "libgomp.so.1", + }; + + /** Whether the native library has been loaded. */ + private static volatile boolean libraryLoaded = false; + + /** Lock for thread-safe library loading. */ + private static final Object LOAD_LOCK = new Object(); + + /** Temporary directory for extracting native libraries. */ + private static Path tempDir; + + private NativeLibraryLoader() { + // Utility class, no instantiation + } + + /** + * Load the native library. + * + * @throws DiskAnnException if the library cannot be loaded + */ + public static void load() throws DiskAnnException { + if (libraryLoaded) { + return; + } + + synchronized (LOAD_LOCK) { + if (libraryLoaded) { + return; + } + + try { + loadNativeLibrary(); + libraryLoaded = true; + LOG.info("DiskANN native library loaded successfully"); + } catch (Exception e) { + throw new DiskAnnException("Failed to load DiskANN native library", e); + } + } + } + + /** + * Check if the native library has been loaded. + * + * @return true if the library is loaded + */ + public static boolean isLoaded() { + return libraryLoaded; + } + + private static void loadNativeLibrary() throws IOException { + // First, try loading from custom path + String customPath = System.getProperty(LIBRARY_PATH_PROPERTY); + if (customPath != null && !customPath.isEmpty()) { + File customLibrary = new File(customPath); + if (customLibrary.exists()) { + System.load(customLibrary.getAbsolutePath()); + LOG.info("Loaded DiskANN native library from custom path: {}", customPath); + return; + } else { + LOG.warn("Custom library path specified but file not found: {}", customPath); + } + } + + // Second, try loading from system library path + try { + System.loadLibrary(JNI_LIBRARY_NAME); + LOG.info("Loaded DiskANN native library from system path"); + return; + } catch (UnsatisfiedLinkError e) { + LOG.debug( + "Could not load from system path, trying bundled library: {}", e.getMessage()); + } + + // Third, try loading from JAR + loadFromJar(); + } + + private static void loadFromJar() throws IOException { + String libraryPath = getLibraryResourcePath(); + LOG.debug("Attempting to load native library from JAR: {}", libraryPath); + + try (InputStream is = NativeLibraryLoader.class.getResourceAsStream(libraryPath)) { + if (is == null) { + throw new IOException( + "Native library not found in JAR: " + + libraryPath + + ". 
" + + "Make sure you are using the correct JAR for your platform (" + + getPlatformIdentifier() + + ")"); + } + + // Create temp directory if needed + if (tempDir == null) { + tempDir = Files.createTempDirectory("paimon-diskann-native"); + tempDir.toFile().deleteOnExit(); + } + + // Extract and load dependency libraries (if bundled) + loadDependencyLibraries(); + + // Extract native library to temp file + String fileName = System.mapLibraryName(JNI_LIBRARY_NAME); + File tempFile = new File(tempDir.toFile(), fileName); + tempFile.deleteOnExit(); + + try (OutputStream os = new FileOutputStream(tempFile)) { + byte[] buffer = new byte[8192]; + int bytesRead; + while ((bytesRead = is.read(buffer)) != -1) { + os.write(buffer, 0, bytesRead); + } + } + + // Make the file executable (for Unix-like systems) + if (!tempFile.setExecutable(true)) { + LOG.warn("Could not set executable permission on native library"); + } + + // Load the library + System.load(tempFile.getAbsolutePath()); + LOG.info("Loaded DiskANN native library from JAR: {}", libraryPath); + } + } + + private static void loadDependencyLibraries() { + String os = getOsName(); + String arch = getArchName(); + + for (String depLib : DEPENDENCY_LIBRARIES) { + String resourcePath = "/" + os + "/" + arch + "/" + depLib; + try (InputStream is = NativeLibraryLoader.class.getResourceAsStream(resourcePath)) { + if (is == null) { + LOG.warn("Dependency library not bundled: {}", depLib); + continue; + } + + File tempFile = new File(tempDir.toFile(), depLib); + tempFile.deleteOnExit(); + + try (OutputStream fos = new FileOutputStream(tempFile)) { + byte[] buffer = new byte[8192]; + int bytesRead; + while ((bytesRead = is.read(buffer)) != -1) { + fos.write(buffer, 0, bytesRead); + } + } + + if (!tempFile.setExecutable(true)) { + LOG.warn("Could not set executable permission on: {}", depLib); + } + + System.load(tempFile.getAbsolutePath()); + LOG.info("Loaded bundled dependency library: {}", depLib); + } catch (UnsatisfiedLinkError e) { + LOG.warn("Could not load dependency {}: {}", depLib, e.getMessage()); + } catch (IOException e) { + LOG.warn("Could not extract dependency {}: {}", depLib, e.getMessage()); + } + } + } + + private static String getLibraryResourcePath() { + String os = getOsName(); + String arch = getArchName(); + String libraryFileName = System.mapLibraryName(JNI_LIBRARY_NAME); + return "/" + os + "/" + arch + "/" + libraryFileName; + } + + static String getPlatformIdentifier() { + return getOsName() + "/" + getArchName(); + } + + private static String getOsName() { + String osName = System.getProperty("os.name").toLowerCase(); + + if (osName.contains("linux")) { + return "linux"; + } else if (osName.contains("mac") || osName.contains("darwin")) { + return "darwin"; + } else { + throw new UnsupportedOperationException( + "Unsupported operating system: " + + osName + + ". 
Only Linux and macOS are supported."); + } + } + + private static String getArchName() { + String osArch = System.getProperty("os.arch").toLowerCase(); + + if (osArch.equals("amd64") || osArch.equals("x86_64")) { + return "amd64"; + } else if (osArch.equals("aarch64") || osArch.equals("arm64")) { + return "aarch64"; + } else { + throw new UnsupportedOperationException("Unsupported architecture: " + osArch); + } + } +} diff --git a/paimon-diskann/paimon-diskann-jni/src/main/native/Cargo.lock b/paimon-diskann/paimon-diskann-jni/src/main/native/Cargo.lock new file mode 100644 index 000000000000..2730f7b3d5db --- /dev/null +++ b/paimon-diskann/paimon-diskann-jni/src/main/native/Cargo.lock @@ -0,0 +1,768 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +version = 4 + +[[package]] +name = "anyhow" +version = "1.0.101" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5f0e0fee31ef5ed1ba1316088939cea399010ed7731dba877ed44aeb407a75ea" + +[[package]] +name = "autocfg" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8" + +[[package]] +name = "bitflags" +version = "2.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "812e12b5285cc515a9c72a5c1d3b6d46a19dac5acfef5265968c166106e31dd3" + +[[package]] +name = "bytemuck" +version = "1.25.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c8efb64bd706a16a1bdde310ae86b351e4d21550d98d056f22f8a7f7a2183fec" +dependencies = [ + "bytemuck_derive", +] + +[[package]] +name = "bytemuck_derive" +version = "1.10.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f9abbd1bc6865053c427f7198e6af43bfdedc55ab791faed4fbd361d789575ff" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "bytes" +version = "1.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e748733b7cbc798e1434b6ac524f0c1ff2ab456fe201501e6497c8417a4fc33" + +[[package]] +name = "cesu8" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6d43a04d8753f35258c91f8ec639f792891f748a1edbd759cf1dcea3382ad83c" + +[[package]] +name = "cfg-if" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9330f8b2ff13f34540b44e946ef35111825727b38d33286ef986142615121801" + +[[package]] +name = "combine" +version = "4.6.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"ba5a308b75df32fe02788e748662718f03fde005016435c444eea572398219fd" +dependencies = [ + "bytes", + "memchr", +] + +[[package]] +name = "crossbeam-deque" +version = "0.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9dd111b7b7f7d55b72c0a6ae361660ee5853c9af73f70c3c2ef6858b950e2e51" +dependencies = [ + "crossbeam-epoch", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-epoch" +version = "0.9.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e" +dependencies = [ + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-utils" +version = "0.8.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28" + +[[package]] +name = "crunchy" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "460fbee9c2c2f33933d720630a6a0bac33ba7053db5344fac858d4b8952d77d5" + +[[package]] +name = "dashmap" +version = "6.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5041cc499144891f3790297212f32a74fb938e5136a14943f338ef9e0ae276cf" +dependencies = [ + "cfg-if", + "crossbeam-utils", + "hashbrown 0.14.5", + "lock_api", + "once_cell", + "parking_lot_core", +] + +[[package]] +name = "diskann" +version = "0.45.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "52046f3e2d811ff08d458ddc649aaecddf2ba404c60086e556fb09436f3a9f4f" +dependencies = [ + "anyhow", + "bytemuck", + "dashmap", + "diskann-utils", + "diskann-vector", + "diskann-wide", + "futures-util", + "half", + "hashbrown 0.16.1", + "num-traits", + "rand", + "thiserror 2.0.18", + "tokio", + "tracing", +] + +[[package]] +name = "diskann-quantization" +version = "0.45.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f6703f9a574be9bf6c9f1f59033e127f5300cd240e50fc40dffe9a64475f92d5" +dependencies = [ + "bytemuck", + "cfg-if", + "diskann-utils", + "diskann-vector", + "diskann-wide", + "half", + "rand", + "rayon", + "thiserror 2.0.18", +] + +[[package]] +name = "diskann-utils" +version = "0.45.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "89752de1b587b64aedca61a53f92bf82903df17fe483413ae5cfc6d061fe2cd3" +dependencies = [ + "cfg-if", + "diskann-vector", + "diskann-wide", + "half", + "rand", + "rayon", + "thiserror 2.0.18", +] + +[[package]] +name = "diskann-vector" +version = "0.45.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e6e50f7003938f1572e8a8b91e81fd2693bd40413e8d477a2e440da278bca7c" +dependencies = [ + "cfg-if", + "diskann-wide", + "half", +] + +[[package]] +name = "diskann-wide" +version = "0.45.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8c4507634ff2929569ea3b60d8f60ce22ff4c3b04d5aca864b6732a9e3f03f99" +dependencies = [ + "cfg-if", + "half", +] + +[[package]] +name = "either" +version = "1.15.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719" + +[[package]] +name = "foldhash" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77ce24cb58228fbb8aa041425bb1050850ac19177686ea6e0f41a70416f56fdb" + +[[package]] +name = "futures-core" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"05f29059c0c2090612e8d742178b0580d2dc940c837851ad723096f87af6663e" + +[[package]] +name = "futures-macro" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "futures-task" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f90f7dce0722e95104fcb095585910c0977252f286e354b5e3bd38902cd99988" + +[[package]] +name = "futures-util" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81" +dependencies = [ + "futures-core", + "futures-macro", + "futures-task", + "pin-project-lite", + "pin-utils", + "slab", +] + +[[package]] +name = "getrandom" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "899def5c37c4fd7b2664648c28120ecec138e4d395b459e5ca34f9cce2dd77fd" +dependencies = [ + "cfg-if", + "libc", + "r-efi", + "wasip2", +] + +[[package]] +name = "half" +version = "2.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ea2d84b969582b4b1864a92dc5d27cd2b77b622a8d79306834f1be5ba20d84b" +dependencies = [ + "bytemuck", + "cfg-if", + "crunchy", + "num-traits", + "rand", + "rand_distr", + "zerocopy", +] + +[[package]] +name = "hashbrown" +version = "0.14.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" + +[[package]] +name = "hashbrown" +version = "0.16.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "841d1cc9bed7f9236f321df977030373f4a4163ae1a7dbfe1a51a2c1a51d9100" +dependencies = [ + "foldhash", +] + +[[package]] +name = "jni" +version = "0.21.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1a87aa2bb7d2af34197c04845522473242e1aa17c12f4935d5856491a7fb8c97" +dependencies = [ + "cesu8", + "cfg-if", + "combine", + "jni-sys", + "log", + "thiserror 1.0.69", + "walkdir", + "windows-sys 0.45.0", +] + +[[package]] +name = "jni-sys" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8eaf4bc02d17cbdd7ff4c7438cafcdf7fb9a4613313ad11b4f8fefe7d3fa0130" + +[[package]] +name = "libc" +version = "0.2.181" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "459427e2af2b9c839b132acb702a1c654d95e10f8c326bfc2ad11310e458b1c5" + +[[package]] +name = "libm" +version = "0.2.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6d2cec3eae94f9f509c767b45932f1ada8350c4bdb85af2fcab4a3c14807981" + +[[package]] +name = "lock_api" +version = "0.4.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "224399e74b87b5f3557511d98dff8b14089b3dadafcab6bb93eab67d3aace965" +dependencies = [ + "scopeguard", +] + +[[package]] +name = "log" +version = "0.4.29" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5e5032e24019045c762d3c0f28f5b6b8bbf38563a65908389bf7978758920897" + +[[package]] +name = "memchr" +version = "2.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8ca58f447f06ed17d5fc4043ce1b10dd205e060fb3ce5b979b8ed8e59ff3f79" + +[[package]] +name = "num-traits" +version = "0.2.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841" +dependencies = [ + "autocfg", + "libm", +] + +[[package]] +name = "once_cell" +version = "1.21.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d" + +[[package]] +name = "paimon_diskann_jni" +version = "0.1.0" +dependencies = [ + "dashmap", + "diskann", + "diskann-quantization", + "diskann-utils", + "diskann-vector", + "futures-util", + "jni", + "tokio", +] + +[[package]] +name = "parking_lot_core" +version = "0.9.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2621685985a2ebf1c516881c026032ac7deafcda1a2c9b7850dc81e3dfcb64c1" +dependencies = [ + "cfg-if", + "libc", + "redox_syscall", + "smallvec", + "windows-link", +] + +[[package]] +name = "pin-project-lite" +version = "0.2.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3b3cff922bd51709b605d9ead9aa71031d81447142d828eb4a6eba76fe619f9b" + +[[package]] +name = "pin-utils" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" + +[[package]] +name = "ppv-lite86" +version = "0.2.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85eae3c4ed2f50dcfe72643da4befc30deadb458a9b590d720cde2f2b1e97da9" +dependencies = [ + "zerocopy", +] + +[[package]] +name = "proc-macro2" +version = "1.0.106" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8fd00f0bb2e90d81d1044c2b32617f68fcb9fa3bb7640c23e9c748e53fb30934" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "quote" +version = "1.0.44" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "21b2ebcf727b7760c461f091f9f0f539b77b8e87f2fd88131e7f1b433b3cece4" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "r-efi" +version = "5.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f" + +[[package]] +name = "rand" +version = "0.9.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6db2770f06117d490610c7488547d543617b21bfa07796d7a12f6f1bd53850d1" +dependencies = [ + "rand_chacha", + "rand_core", +] + +[[package]] +name = "rand_chacha" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3022b5f1df60f26e1ffddd6c66e8aa15de382ae63b3a0c1bfc0e4d3e3f325cb" +dependencies = [ + "ppv-lite86", + "rand_core", +] + +[[package]] +name = "rand_core" +version = "0.9.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "76afc826de14238e6e8c374ddcc1fa19e374fd8dd986b0d2af0d02377261d83c" +dependencies = [ + "getrandom", +] + +[[package]] +name = "rand_distr" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6a8615d50dcf34fa31f7ab52692afec947c4dd0ab803cc87cb3b0b4570ff7463" +dependencies = [ + "num-traits", + "rand", +] + +[[package]] +name = "rayon" +version = "1.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "368f01d005bf8fd9b1206fb6fa653e6c4a81ceb1466406b81792d87c5677a58f" +dependencies = [ + "either", + "rayon-core", +] + +[[package]] +name = "rayon-core" +version = "1.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"22e18b0f0062d30d4230b2e85ff77fdfe4326feb054b9783a3460d8435c8ab91" +dependencies = [ + "crossbeam-deque", + "crossbeam-utils", +] + +[[package]] +name = "redox_syscall" +version = "0.5.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ed2bf2547551a7053d6fdfafda3f938979645c44812fbfcda098faae3f1a362d" +dependencies = [ + "bitflags", +] + +[[package]] +name = "same-file" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93fc1dc3aaa9bfed95e02e6eadabb4baf7e3078b0bd1b4d7b6b0b68378900502" +dependencies = [ + "winapi-util", +] + +[[package]] +name = "scopeguard" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" + +[[package]] +name = "slab" +version = "0.4.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a2ae44ef20feb57a68b23d846850f861394c2e02dc425a50098ae8c90267589" + +[[package]] +name = "smallvec" +version = "1.15.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "67b1b7a3b5fe4f1376887184045fcf45c69e92af734b7aaddc05fb777b6fbd03" + +[[package]] +name = "syn" +version = "2.0.114" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d4d107df263a3013ef9b1879b0df87d706ff80f65a86ea879bd9c31f9b307c2a" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "thiserror" +version = "1.0.69" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6aaf5339b578ea85b50e080feb250a3e8ae8cfcdff9a461c9ec2904bc923f52" +dependencies = [ + "thiserror-impl 1.0.69", +] + +[[package]] +name = "thiserror" +version = "2.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4288b5bcbc7920c07a1149a35cf9590a2aa808e0bc1eafaade0b80947865fbc4" +dependencies = [ + "thiserror-impl 2.0.18", +] + +[[package]] +name = "thiserror-impl" +version = "1.0.69" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "thiserror-impl" +version = "2.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ebc4ee7f67670e9b64d05fa4253e753e016c6c95ff35b89b7941d6b856dec1d5" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "tokio" +version = "1.49.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72a2903cd7736441aac9df9d7688bd0ce48edccaadf181c3b90be801e81d3d86" +dependencies = [ + "pin-project-lite", +] + +[[package]] +name = "tracing" +version = "0.1.44" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "63e71662fa4b2a2c3a26f570f037eb95bb1f85397f3cd8076caed2f026a6d100" +dependencies = [ + "pin-project-lite", + "tracing-attributes", + "tracing-core", +] + +[[package]] +name = "tracing-attributes" +version = "0.1.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7490cfa5ec963746568740651ac6781f701c9c5ea257c58e057f3ba8cf69e8da" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "tracing-core" +version = "0.1.36" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "db97caf9d906fbde555dd62fa95ddba9eecfd14cb388e4f491a66d74cd5fb79a" +dependencies = [ + "once_cell", +] + +[[package]] +name = "unicode-ident" 
+version = "1.0.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "537dd038a89878be9b64dd4bd1b260315c1bb94f4d784956b81e27a088d9a09e" + +[[package]] +name = "walkdir" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "29790946404f91d9c5d06f9874efddea1dc06c5efe94541a7d6863108e3a5e4b" +dependencies = [ + "same-file", + "winapi-util", +] + +[[package]] +name = "wasip2" +version = "1.0.2+wasi-0.2.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9517f9239f02c069db75e65f174b3da828fe5f5b945c4dd26bd25d89c03ebcf5" +dependencies = [ + "wit-bindgen", +] + +[[package]] +name = "winapi-util" +version = "0.1.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c2a7b1c03c876122aa43f3020e6c3c3ee5c05081c9a00739faf7503aeba10d22" +dependencies = [ + "windows-sys 0.61.2", +] + +[[package]] +name = "windows-link" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5" + +[[package]] +name = "windows-sys" +version = "0.45.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75283be5efb2831d37ea142365f009c02ec203cd29a3ebecbc093d52315b66d0" +dependencies = [ + "windows-targets", +] + +[[package]] +name = "windows-sys" +version = "0.61.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ae137229bcbd6cdf0f7b80a31df61766145077ddf49416a728b02cb3921ff3fc" +dependencies = [ + "windows-link", +] + +[[package]] +name = "windows-targets" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e5180c00cd44c9b1c88adb3693291f1cd93605ded80c250a75d472756b4d071" +dependencies = [ + "windows_aarch64_gnullvm", + "windows_aarch64_msvc", + "windows_i686_gnu", + "windows_i686_msvc", + "windows_x86_64_gnu", + "windows_x86_64_gnullvm", + "windows_x86_64_msvc", +] + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "597a5118570b68bc08d8d59125332c54f1ba9d9adeedeef5b99b02ba2b0698f8" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e08e8864a60f06ef0d0ff4ba04124db8b0fb3be5776a5cd47641e942e58c4d43" + +[[package]] +name = "windows_i686_gnu" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c61d927d8da41da96a81f029489353e68739737d3beca43145c8afec9a31a84f" + +[[package]] +name = "windows_i686_msvc" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "44d840b6ec649f480a41c8d80f9c65108b92d89345dd94027bfe06ac444d1060" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8de912b8b8feb55c064867cf047dda097f92d51efad5b491dfb98f6bbb70cb36" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "26d41b46a36d453748aedef1486d5c7a85db22e56aff34643984ea85514e94a3" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9aec5da331524158c6d1a4ac0ab1541149c0b9505fde06423b02f5ef0106b9f0" + +[[package]] +name = "wit-bindgen" +version = "0.51.0" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "d7249219f66ced02969388cf2bb044a09756a083d0fab1e566056b04d9fbcaa5" + +[[package]] +name = "zerocopy" +version = "0.8.39" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "db6d35d663eadb6c932438e763b262fe1a70987f9ae936e60158176d710cae4a" +dependencies = [ + "zerocopy-derive", +] + +[[package]] +name = "zerocopy-derive" +version = "0.8.39" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4122cd3169e94605190e77839c9a40d40ed048d305bfdc146e7df40ab0f3e517" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] diff --git a/paimon-diskann/paimon-diskann-jni/src/main/native/Cargo.toml b/paimon-diskann/paimon-diskann-jni/src/main/native/Cargo.toml new file mode 100644 index 000000000000..538d90bb393f --- /dev/null +++ b/paimon-diskann/paimon-diskann-jni/src/main/native/Cargo.toml @@ -0,0 +1,39 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +[package] +name = "paimon_diskann_jni" +version = "0.1.0" +edition = "2021" + +[lib] +crate-type = ["cdylib"] + +[dependencies] +jni = "0.21" +# DiskANN with testing feature: Provides test_provider for in-memory vector storage +# Note: Despite the name "testing", this is used in production for the in-memory index implementation +diskann = { version = "0.45.0", features = ["testing"] } +diskann-vector = "0.45.0" +# Tokio with rt feature: Required for async DiskANN operations +# Uses single-threaded runtime (new_current_thread) for efficient resource usage +tokio = { version = "1", features = ["rt"] } +# Futures utilities (for futures_util::future::ok in JniProvider) +futures-util = "0.3" +# Concurrent HashMap used by the JniProvider graph storage +dashmap = "6" +diskann-quantization = "0.45.0" +diskann-utils = "0.45.0" diff --git a/paimon-diskann/paimon-diskann-jni/src/main/native/rust-toolchain.toml b/paimon-diskann/paimon-diskann-jni/src/main/native/rust-toolchain.toml new file mode 100644 index 000000000000..6e8dd640a2ef --- /dev/null +++ b/paimon-diskann/paimon-diskann-jni/src/main/native/rust-toolchain.toml @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Microsoft DiskANN v0.45.0 requires Rust >= 1.87 due to +# diskann-vector's use of unsigned_is_multiple_of (stabilised in 1.87). +[toolchain] +channel = "1.90.0" diff --git a/paimon-diskann/paimon-diskann-jni/src/main/native/src/lib.rs b/paimon-diskann/paimon-diskann-jni/src/main/native/src/lib.rs new file mode 100644 index 000000000000..27e04f72f101 --- /dev/null +++ b/paimon-diskann/paimon-diskann-jni/src/main/native/src/lib.rs @@ -0,0 +1,1379 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +//! JNI bindings for Apache Paimon's DiskANN vector index. +//! +//! This module uses Microsoft's official `diskann` Rust crate (v0.45.0) +//! from to provide graph-based +//! approximate nearest neighbor search via JNI. +//! +//! # JNI Safety +//! +//! Every `extern "system"` entry point is wrapped with [`std::panic::catch_unwind`] +//! so that a Rust panic never unwinds across the FFI boundary, which would cause +//! undefined behaviour and likely crash the JVM. On panic the function throws a +//! `java.lang.RuntimeException` with the panic message and returns a safe default. + +use jni::objects::{JByteBuffer, JClass, JObject, JPrimitiveArray, ReleaseMode}; +use jni::sys::{jfloat, jint, jlong}; +use jni::JNIEnv; + +use std::collections::HashMap; +use std::panic::{self, AssertUnwindSafe}; +use std::sync::{Arc, Mutex, OnceLock}; + +use diskann::graph::test::provider as test_provider; +use diskann::graph::{self, DiskANNIndex}; +use diskann::neighbor::{BackInserter, Neighbor}; +use diskann_vector::distance::Metric; + +mod paimon_fileio_provider; +mod pq; +use paimon_fileio_provider::FileIOProvider; + +// ======================== Constants ======================== + +const METRIC_L2: i32 = 0; +const METRIC_INNER_PRODUCT: i32 = 1; +const METRIC_COSINE: i32 = 2; + +/// The u32 ID reserved for the DiskANN graph start/entry point. +const START_POINT_ID: u32 = 0; + +// ======================== Panic‐safe JNI helper ======================== + +/// Run `body` inside [`catch_unwind`]. If it panics, throw a Java +/// `RuntimeException` with the panic message and return `default`. 
+fn jni_catch_unwind(env: &mut JNIEnv, default: R, body: F) -> R +where + F: FnOnce() -> R + panic::UnwindSafe, +{ + match panic::catch_unwind(body) { + Ok(v) => v, + Err(payload) => { + let msg = if let Some(s) = payload.downcast_ref::<&str>() { + s.to_string() + } else if let Some(s) = payload.downcast_ref::() { + s.clone() + } else { + "Unknown Rust panic in DiskANN JNI".to_string() + }; + let _ = env.throw_new("java/lang/RuntimeException", msg); + default + } + } +} + +// ======================== Metric Mapping ======================== + +pub(crate) fn map_metric(metric_type: i32) -> Metric { + match metric_type { + METRIC_INNER_PRODUCT => Metric::InnerProduct, + METRIC_COSINE => Metric::Cosine, + _ => Metric::L2, + } +} + +// ======================== Index State ======================== + +struct IndexState { + index: Arc>, + context: test_provider::Context, + runtime: tokio::runtime::Runtime, + + dimension: i32, + metric_type: i32, + + next_id: u32, + + /// Vectors stored in insertion order. Position i has int_id = i + 1. + raw_data: Vec>, +} + +// ======================== Registry ======================== + +struct IndexRegistry { + next_handle: i64, + indices: HashMap>>, +} + +impl IndexRegistry { + fn new() -> Self { + Self { + next_handle: 1, + indices: HashMap::new(), + } + } + + fn insert(&mut self, state: IndexState) -> i64 { + let handle = self.next_handle; + self.next_handle += 1; + self.indices.insert(handle, Arc::new(Mutex::new(state))); + handle + } +} + +fn registry() -> &'static Mutex { + static REGISTRY: OnceLock> = OnceLock::new(); + REGISTRY.get_or_init(|| Mutex::new(IndexRegistry::new())) +} + +fn get_index(handle: i64) -> Option>> { + let guard = registry().lock().ok()?; + guard.indices.get(&handle).cloned() +} + +// ======================== Index Construction ======================== + +fn create_index_state( + dimension: i32, + metric_type: i32, + _index_type: i32, + max_degree: i32, + build_list_size: i32, +) -> Result { + let dim = dimension as usize; + let metric = map_metric(metric_type); + let md = std::cmp::max(max_degree as usize, 4); + let bls = std::cmp::max(build_list_size as usize, md); + + let runtime = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .map_err(|e| format!("Failed to create tokio runtime: {}", e))?; + + let start_vector = vec![1.0f32; dim]; + let provider_config = test_provider::Config::new( + metric, + md, + test_provider::StartPoint::new(START_POINT_ID, start_vector), + ) + .map_err(|e| format!("Failed to create provider config: {:?}", e))?; + let provider = test_provider::Provider::new(provider_config); + + let index_config = graph::config::Builder::new( + md, + graph::config::MaxDegree::same(), + bls, + metric.into(), + ) + .build() + .map_err(|e| format!("Failed to create index config: {:?}", e))?; + + let index = Arc::new(DiskANNIndex::new(index_config, provider, None)); + let context = test_provider::Context::default(); + + Ok(IndexState { + index, + context, + runtime, + dimension, + metric_type, + next_id: START_POINT_ID + 1, + raw_data: Vec::new(), + }) +} + +// ======================== Buffer Helpers ======================== + +fn get_direct_buffer_slice<'a>( + env: &mut JNIEnv, + buffer: &JByteBuffer, + len: usize, +) -> Option<&'a mut [u8]> { + let ptr = env.get_direct_buffer_address(buffer).ok()?; + let capacity = env.get_direct_buffer_capacity(buffer).ok()?; + if capacity < len { + return None; + } + unsafe { Some(std::slice::from_raw_parts_mut(ptr, len)) } +} + +// ======================== 
Serialization Helpers ======================== + +fn write_i32(buf: &mut [u8], offset: &mut usize, v: i32) -> bool { + if *offset + 4 > buf.len() { return false; } + buf[*offset..*offset + 4].copy_from_slice(&v.to_ne_bytes()); + *offset += 4; + true +} + +fn write_f32(buf: &mut [u8], offset: &mut usize, v: f32) -> bool { + if *offset + 4 > buf.len() { return false; } + buf[*offset..*offset + 4].copy_from_slice(&v.to_ne_bytes()); + *offset += 4; + true +} + +// ======================== JNI Functions ======================== + +#[no_mangle] +pub extern "system" fn Java_org_apache_paimon_diskann_DiskAnnNative_indexCreate<'local>( + mut env: JNIEnv<'local>, + _class: JClass<'local>, + dimension: jint, + metric_type: jint, + index_type: jint, + max_degree: jint, + build_list_size: jint, +) -> jlong { + let result = jni_catch_unwind(&mut env, 0i64, AssertUnwindSafe(|| -> jlong { + match create_index_state(dimension, metric_type, index_type, max_degree, build_list_size) { + Ok(state) => match registry().lock() { + Ok(mut guard) => guard.insert(state), + Err(_) => -1, + }, + Err(_) => -2, + } + })); + match result { + -1 => { let _ = env.throw_new("java/lang/IllegalStateException", "DiskANN registry error"); 0 } + -2 => { let _ = env.throw_new("java/lang/RuntimeException", "Failed to create DiskANN index"); 0 } + v => v, + } +} + +#[no_mangle] +pub extern "system" fn Java_org_apache_paimon_diskann_DiskAnnNative_indexDestroy<'local>( + mut env: JNIEnv<'local>, + _class: JClass<'local>, + handle: jlong, +) { + jni_catch_unwind(&mut env, (), AssertUnwindSafe(|| { + if let Ok(mut guard) = registry().lock() { + guard.indices.remove(&handle); + } + })); +} + +#[no_mangle] +pub extern "system" fn Java_org_apache_paimon_diskann_DiskAnnNative_indexGetDimension<'local>( + mut env: JNIEnv<'local>, + _class: JClass<'local>, + handle: jlong, +) -> jint { + jni_catch_unwind(&mut env, 0, AssertUnwindSafe(|| { + get_index(handle) + .and_then(|arc| arc.lock().ok().map(|s| s.dimension)) + .unwrap_or(0) + })) +} + +#[no_mangle] +pub extern "system" fn Java_org_apache_paimon_diskann_DiskAnnNative_indexGetCount<'local>( + mut env: JNIEnv<'local>, + _class: JClass<'local>, + handle: jlong, +) -> jlong { + jni_catch_unwind(&mut env, 0, AssertUnwindSafe(|| { + get_index(handle) + .and_then(|arc| arc.lock().ok().map(|s| s.raw_data.len() as jlong)) + .unwrap_or(0) + })) +} + +#[no_mangle] +pub extern "system" fn Java_org_apache_paimon_diskann_DiskAnnNative_indexGetMetricType<'local>( + mut env: JNIEnv<'local>, + _class: JClass<'local>, + handle: jlong, +) -> jint { + jni_catch_unwind(&mut env, 0, AssertUnwindSafe(|| { + get_index(handle) + .and_then(|arc| arc.lock().ok().map(|s| s.metric_type)) + .unwrap_or(0) + })) +} + +#[no_mangle] +pub extern "system" fn Java_org_apache_paimon_diskann_DiskAnnNative_indexAdd<'local>( + mut env: JNIEnv<'local>, + _class: JClass<'local>, + handle: jlong, + n: jlong, + vector_buffer: JByteBuffer<'local>, +) { + let arc = match get_index(handle) { + Some(a) => a, + None => { + let _ = env.throw_new("java/lang/IllegalArgumentException", "Invalid index handle"); + return; + } + }; + let mut state = match arc.lock() { + Ok(s) => s, + Err(_) => { + let _ = env.throw_new("java/lang/IllegalStateException", "Index lock poisoned"); + return; + } + }; + + let num = n as usize; + let dimension = state.dimension as usize; + let vec_len = num * dimension * 4; + + let vec_bytes = match get_direct_buffer_slice(&mut env, &vector_buffer, vec_len) { + Some(slice) => slice, + None => { + let _ = 
env.throw_new("java/lang/IllegalArgumentException", "Invalid vector buffer"); + return; + } + }; + + let vectors = + unsafe { std::slice::from_raw_parts(vec_bytes.as_ptr() as *const f32, num * dimension) }; + + let strat = test_provider::Strategy::new(); + + for i in 0..num { + let base = i * dimension; + let vector = vectors[base..base + dimension].to_vec(); + + let int_id = state.next_id; + state.next_id += 1; + state.raw_data.push(vector.clone()); + + // catch_unwind around the DiskANN graph insert which may panic. + let idx_clone = Arc::clone(&state.index); + let ctx = &state.context; + let result = panic::catch_unwind(AssertUnwindSafe(|| { + state.runtime.block_on(idx_clone.insert(strat, ctx, &int_id, vector.as_slice())) + })); + + match result { + Ok(Ok(())) => {} + Ok(Err(e)) => { + let _ = env.throw_new( + "java/lang/RuntimeException", + format!("DiskANN insert failed for int_id {}: {}", int_id, e), + ); + return; + } + Err(_) => { + let _ = env.throw_new( + "java/lang/RuntimeException", + format!("DiskANN insert panicked for int_id {}", int_id), + ); + return; + } + } + } +} + +#[no_mangle] +pub extern "system" fn Java_org_apache_paimon_diskann_DiskAnnNative_indexBuild<'local>( + mut env: JNIEnv<'local>, + _class: JClass<'local>, + handle: jlong, + _build_list_size: jint, +) { + jni_catch_unwind(&mut env, (), AssertUnwindSafe(|| { + if get_index(handle).is_none() { + // Will be caught below. + panic!("Invalid index handle"); + } + })); +} + +#[no_mangle] +pub extern "system" fn Java_org_apache_paimon_diskann_DiskAnnNative_indexSearch<'local>( + mut env: JNIEnv<'local>, + _class: JClass<'local>, + handle: jlong, + n: jlong, + query_vectors: JPrimitiveArray<'local, jfloat>, + k: jint, + search_list_size: jint, + distances: JPrimitiveArray<'local, jfloat>, + labels: JPrimitiveArray<'local, jlong>, +) { + let num = n as usize; + let top_k = k as usize; + + let arc = match get_index(handle) { + Some(a) => a, + None => { + let _ = env.throw_new("java/lang/IllegalArgumentException", "Invalid index handle"); + return; + } + }; + + // Read query vectors into owned Vec. + let query: Vec = { + let query_elements = + match unsafe { env.get_array_elements(&query_vectors, ReleaseMode::NoCopyBack) } { + Ok(arr) => arr, + Err(_) => { + let _ = env.throw_new( + "java/lang/IllegalArgumentException", + "Invalid query vectors", + ); + return; + } + }; + query_elements.iter().copied().collect() + }; + + let state = match arc.lock() { + Ok(s) => s, + Err(_) => { + let _ = env.throw_new("java/lang/IllegalStateException", "Index lock poisoned"); + return; + } + }; + + let dimension = state.dimension as usize; + let total_results = num * top_k; + let mut result_distances = vec![f32::MAX; total_results]; + let mut result_labels = vec![-1i64; total_results]; + + if !state.raw_data.is_empty() { + let strat = test_provider::Strategy::new(); + + for qi in 0..num { + let query_vec = &query[qi * dimension..(qi + 1) * dimension]; + + let search_k = top_k + 1; + let l_value = std::cmp::max(search_list_size as usize, search_k); + + let params = match graph::SearchParams::new(search_k, l_value, None) { + Ok(p) => p, + Err(e) => { + let _ = env.throw_new( + "java/lang/IllegalArgumentException", + format!("Invalid search params: {}", e), + ); + return; + } + }; + + let mut neighbors = vec![Neighbor::::default(); search_k]; + + // catch_unwind around graph search. 
+ let idx_clone = Arc::clone(&state.index); + let ctx = &state.context; + let search_result = panic::catch_unwind(AssertUnwindSafe(|| { + state.runtime.block_on(idx_clone.search( + &strat, + ctx, + query_vec, + ¶ms, + &mut BackInserter::new(&mut neighbors), + )) + })); + + let stats = match search_result { + Ok(Ok(s)) => s, + Ok(Err(e)) => { + let _ = env.throw_new( + "java/lang/RuntimeException", + format!("DiskANN search failed: {}", e), + ); + return; + } + Err(_) => { + let _ = env.throw_new( + "java/lang/RuntimeException", + "DiskANN search panicked", + ); + return; + } + }; + + let result_count = stats.result_count as usize; + let mut count = 0; + for ri in 0..result_count { + if count >= top_k { + break; + } + let neighbor = &neighbors[ri]; + if neighbor.id == START_POINT_ID { + continue; + } + let idx = qi * top_k + count; + result_labels[idx] = (neighbor.id as i64) - 1; + result_distances[idx] = neighbor.distance; + count += 1; + } + } + } + + drop(state); + + // Write distances back. + { + let mut dist_elements = + match unsafe { env.get_array_elements(&distances, ReleaseMode::CopyBack) } { + Ok(arr) => arr, + Err(_) => { + let _ = + env.throw_new("java/lang/IllegalArgumentException", "Invalid distances"); + return; + } + }; + for i in 0..std::cmp::min(dist_elements.len(), result_distances.len()) { + dist_elements[i] = result_distances[i]; + } + } + + // Write labels back. + { + let mut label_elements = + match unsafe { env.get_array_elements(&labels, ReleaseMode::CopyBack) } { + Ok(arr) => arr, + Err(_) => { + let _ = env.throw_new("java/lang/IllegalArgumentException", "Invalid labels"); + return; + } + }; + for i in 0..std::cmp::min(label_elements.len(), result_labels.len()) { + label_elements[i] = result_labels[i]; + } + } +} + +// ============================================================================ +// Serialization format (graph + data, no header) +// ============================================================================ +// +// The index file (.index) contains ONLY the graph adjacency lists: +// Graph : for each node (start point + user vectors): +// int_id : i32 +// neighbor_cnt : i32 +// neighbors : neighbor_cnt × i32 +// +// The data file (.data) contains ONLY raw vectors stored sequentially: +// Data : for each user vector (in order 0, 1, 2, ...): +// vector : dim × f32 +// +// The sequential position IS the ID. +// The start point is NOT stored in the data file. +// position = int_id - 1 for user vectors (int_id > 0). +// +// All metadata (dimension, metric, max_degree, build_list_size, count, +// start_id) is stored in DiskAnnIndexMeta — not in the file. +// +// During search, both graph and vector data are read on demand from +// Paimon FileIO-backed storage (local, HDFS, S3, OSS, etc.) via JNI callbacks: +// - Graph: FileIOGraphReader.readNeighbors(int) +// - Vectors: FileIOVectorReader.readVector(long) +// ============================================================================ + +// ---- Searcher registry (handles backed by FileIOProvider) ---- + +struct SearcherState { + /// The DiskANN index (holds the FileIOProvider which has PQ data + graph/vector readers). + index: Arc>, + dimension: i32, + /// Minimum external ID for this index. ext_id = min_ext_id + vec_idx (0-based). + min_ext_id: i64, + /// Graph start/entry point (medoid) internal ID. + start_id: u32, + /// I/O context for beam search: JVM, reader refs, DirectByteBuffer pointers. + io_ctx: BeamSearchIOContext, +} + +/// I/O context for beam search. 
Provides JNI access to graph reader (for +/// neighbor lists) and vector reader (for full-precision vectors on disk). +struct BeamSearchIOContext { + /// JVM handle for attaching threads. + jvm: jni::JavaVM, + /// GlobalRef to the Java vector reader (`FileIOVectorReader`). + vector_reader_ref: jni::objects::GlobalRef, + /// GlobalRef to the Java graph reader (`FileIOGraphReader`). + graph_reader_ref: jni::objects::GlobalRef, + /// Native address of the single-vector DirectByteBuffer. + single_buf_ptr: *mut f32, + /// Vector dimension. + dim: usize, + /// Distance metric type (0=L2, 1=IP, 2=Cosine). + metric_type: i32, +} + +// SAFETY: same justification as FileIOProvider — JavaVM and GlobalRef are +// Send+Sync, raw pointer access is serialized by single-threaded runtime. +unsafe impl Send for BeamSearchIOContext {} +unsafe impl Sync for BeamSearchIOContext {} + +/// Compute exact distance between two vectors. +fn compute_exact_distance(a: &[f32], b: &[f32], metric_type: i32) -> f32 { + match metric_type { + METRIC_INNER_PRODUCT => { + // Negative inner product (larger IP = more similar → smaller distance). + let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum(); + -dot + } + METRIC_COSINE => { + // 1 − cos_sim + let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum(); + let norm_a: f32 = a.iter().map(|x| x * x).sum::().sqrt(); + let norm_b: f32 = b.iter().map(|x| x * x).sum::().sqrt(); + let denom = norm_a * norm_b; + if denom < 1e-30 { 1.0 } else { 1.0 - dot / denom } + } + _ => { + // Squared L2 distance. + a.iter().zip(b).map(|(x, y)| { let d = x - y; d * d }).sum() + } + } +} + +struct SearcherRegistry { + next_handle: i64, + searchers: HashMap>>, +} + +impl SearcherRegistry { + fn new() -> Self { + Self { next_handle: 100_000, searchers: HashMap::new() } + } + fn insert(&mut self, state: SearcherState) -> i64 { + let h = self.next_handle; + self.next_handle += 1; + self.searchers.insert(h, Arc::new(Mutex::new(state))); + h + } +} + +fn searcher_registry() -> &'static Mutex { + static REG: OnceLock> = OnceLock::new(); + REG.get_or_init(|| Mutex::new(SearcherRegistry::new())) +} + +fn get_searcher(handle: i64) -> Option>> { + searcher_registry().lock().ok()?.searchers.get(&handle).cloned() +} + +// ======================== indexSerialize ======================== + +/// Serialize the index with its graph adjacency lists. +/// Returns the number of bytes written. +#[no_mangle] +pub extern "system" fn Java_org_apache_paimon_diskann_DiskAnnNative_indexSerialize<'local>( + mut env: JNIEnv<'local>, + _class: JClass<'local>, + handle: jlong, + buffer: JByteBuffer<'local>, +) -> jlong { + let arc = match get_index(handle) { + Some(a) => a, + None => { let _ = env.throw_new("java/lang/IllegalArgumentException", "Invalid handle"); return 0; } + }; + let state = match arc.lock() { + Ok(s) => s, + Err(_) => { let _ = env.throw_new("java/lang/IllegalStateException", "Lock poisoned"); return 0; } + }; + + // Collect graph data from the underlying DiskANN test_provider. + let provider = state.index.provider(); + + let dim = state.dimension as usize; + let num_user_vectors = state.raw_data.len(); + let num_nodes = num_user_vectors + 1; // +1 for start point + + // Build ordered list of (int_id, neighbors) using the async + // NeighborAccessor API run synchronously on our tokio runtime. + // Node order: start point (int_id=0) first, then user vectors (int_id=1,2,...). 
+ let mut graph_section_size: usize = 0; + let mut graph_entries: Vec<(u32, Vec)> = Vec::with_capacity(num_nodes); + + for int_id in 0..num_nodes as u32 { + let mut neighbors = Vec::new(); + { + use diskann::graph::AdjacencyList; + use diskann::provider::{DefaultAccessor, NeighborAccessor as NeighborAccessorTrait}; + + let accessor = provider.default_accessor(); + let mut adj = AdjacencyList::::new(); + if state.runtime.block_on(accessor.get_neighbors(int_id, &mut adj)).is_ok() { + neighbors = adj.iter().copied().collect(); + } + } + graph_section_size += 4 + 4 + neighbors.len() * 4; // int_id + cnt + neighbors + graph_entries.push((int_id, neighbors)); + } + + // Data section: user vectors in sequential order (no start point). + let data_section_size = num_user_vectors * dim * 4; + let total_size = graph_section_size + data_section_size; + + let buf = match get_direct_buffer_slice(&mut env, &buffer, total_size) { + Some(s) => s, + None => { let _ = env.throw_new("java/lang/IllegalArgumentException", "Buffer too small"); return 0; } + }; + + let mut off = 0usize; + + // Graph section: int_id(i32) + neighbor_cnt(i32) + neighbors(cnt × i32) + for (int_id, neighbors) in &graph_entries { + write_i32(buf, &mut off, *int_id as i32); + write_i32(buf, &mut off, neighbors.len() as i32); + for &n in neighbors { + write_i32(buf, &mut off, n as i32); + } + } + + // Data section: user vectors in insertion order. + for vec in &state.raw_data { + for &v in vec { + write_f32(buf, &mut off, v); + } + } + + total_size as jlong +} + +/// Return the serialized size. +#[no_mangle] +pub extern "system" fn Java_org_apache_paimon_diskann_DiskAnnNative_indexSerializeSize<'local>( + mut env: JNIEnv<'local>, + _class: JClass<'local>, + handle: jlong, +) -> jlong { + let arc = match get_index(handle) { + Some(a) => a, + None => { let _ = env.throw_new("java/lang/IllegalArgumentException", "Invalid handle"); return 0; } + }; + let state = match arc.lock() { + Ok(s) => s, + Err(_) => { let _ = env.throw_new("java/lang/IllegalStateException", "Lock poisoned"); return 0; } + }; + + let dim = state.dimension as usize; + + // Calculate size by iterating over all graph nodes (start point + user vectors). + let provider = state.index.provider(); + let num_nodes = state.raw_data.len() + 1; // +1 for start point + let mut graph_section_size: usize = 0; + + for int_id in 0..num_nodes as u32 { + let neighbor_count = { + use diskann::graph::AdjacencyList; + use diskann::provider::{DefaultAccessor, NeighborAccessor as NeighborAccessorTrait}; + let accessor = provider.default_accessor(); + let mut adj = AdjacencyList::::new(); + if state.runtime.block_on(accessor.get_neighbors(int_id, &mut adj)).is_ok() { + adj.len() + } else { + 0 + } + }; + graph_section_size += 4 + 4 + neighbor_count * 4; // int_id + cnt + neighbors + } + + // Data section: only user vectors (no start point). + let data_section_size = state.raw_data.len() * dim * 4; + (graph_section_size + data_section_size) as jlong +} + +// ======================== indexCreateSearcherFromReaders ======================== + +/// Create a search-only handle from two on-demand Java readers: one for graph +/// structure and one for vectors, plus PQ data for in-memory approximate +/// distance computation during beam search. +/// +/// `graphReader`: Java object with `readNeighbors(int)`, `getDimension()`, etc. +/// `vectorReader`: Java object with `loadVector(long)`, DirectByteBuffer accessors. +/// `min_ext_id`: Minimum external ID for int_id → ext_id conversion. 
+/// `pq_pivots`: Serialized PQ codebook (byte[]). Must not be null. +/// `pq_compressed`: Serialized PQ compressed codes (byte[]). Must not be null. +/// +/// Returns a searcher handle (≥100000) for use with `indexSearchWithReader`. +#[no_mangle] +pub extern "system" fn Java_org_apache_paimon_diskann_DiskAnnNative_indexCreateSearcherFromReaders<'local>( + mut env: JNIEnv<'local>, + _class: JClass<'local>, + graph_reader: JObject<'local>, + vector_reader: JObject<'local>, + min_ext_id: jlong, + pq_pivots: JObject<'local>, + pq_compressed: JObject<'local>, +) -> jlong { + // Helper to call int-returning methods on graphReader. + macro_rules! call_int { + ($name:expr) => { + match env.call_method(&graph_reader, $name, "()I", &[]) { + Ok(v) => match v.i() { Ok(i) => i, Err(_) => { let _ = env.throw_new("java/lang/RuntimeException", concat!("Bad return from ", stringify!($name))); return 0; } }, + Err(e) => { let _ = env.throw_new("java/lang/RuntimeException", format!("Failed to call {}: {}", $name, e)); return 0; } + } + } + } + + let dimension = call_int!("getDimension"); + let metric_type = call_int!("getMetricValue"); + let max_degree = call_int!("getMaxDegree") as usize; + let build_ls = call_int!("getBuildListSize") as usize; + let count = call_int!("getCount") as usize; + let start_id = call_int!("getStartId") as u32; + let dim = dimension as usize; + + // Create global refs for both readers. + let global_graph_reader = match env.new_global_ref(&graph_reader) { + Ok(g) => g, + Err(e) => { let _ = env.throw_new("java/lang/RuntimeException", format!("graph ref: {}", e)); return 0; } + }; + let global_vector_reader = match env.new_global_ref(&vector_reader) { + Ok(g) => g, + Err(e) => { let _ = env.throw_new("java/lang/RuntimeException", format!("vector ref: {}", e)); return 0; } + }; + + let jvm = match env.get_java_vm() { + Ok(vm) => vm, + Err(e) => { let _ = env.throw_new("java/lang/RuntimeException", format!("get JVM: {}", e)); return 0; } + }; + + // ---- Obtain DirectByteBuffer native pointers from the vector reader ---- + + // Single-vector DirectByteBuffer: getDirectBuffer() → ByteBuffer + let single_buf_ptr: *mut f32 = { + let buf_obj = match env.call_method(&vector_reader, "getDirectBuffer", "()Ljava/nio/ByteBuffer;", &[]) { + Ok(v) => match v.l() { Ok(o) => o, Err(_) => { let _ = env.throw_new("java/lang/RuntimeException", "Bad return from getDirectBuffer"); return 0; } }, + Err(e) => { let _ = env.throw_new("java/lang/RuntimeException", format!("getDirectBuffer: {}", e)); return 0; } + }; + let byte_buf = jni::objects::JByteBuffer::from(buf_obj); + match env.get_direct_buffer_address(&byte_buf) { + Ok(ptr) => ptr as *mut f32, + Err(e) => { let _ = env.throw_new("java/lang/RuntimeException", format!("GetDirectBufferAddress (single): {}", e)); return 0; } + } + }; + + // Batch DirectByteBuffer: getBatchBuffer() → ByteBuffer + let batch_buf_ptr: *mut f32 = { + let buf_obj = match env.call_method(&vector_reader, "getBatchBuffer", "()Ljava/nio/ByteBuffer;", &[]) { + Ok(v) => match v.l() { Ok(o) => o, Err(_) => { let _ = env.throw_new("java/lang/RuntimeException", "Bad return from getBatchBuffer"); return 0; } }, + Err(e) => { let _ = env.throw_new("java/lang/RuntimeException", format!("getBatchBuffer: {}", e)); return 0; } + }; + let byte_buf = jni::objects::JByteBuffer::from(buf_obj); + match env.get_direct_buffer_address(&byte_buf) { + Ok(ptr) => ptr as *mut f32, + Err(e) => { let _ = env.throw_new("java/lang/RuntimeException", format!("GetDirectBufferAddress (batch): {}", e)); 
return 0; } + } + }; + + // Max batch size from the vector reader. + let max_batch_size: usize = match env.call_method(&vector_reader, "getMaxBatchSize", "()I", &[]) { + Ok(v) => match v.i() { Ok(i) => i as usize, Err(_) => max_degree }, + Err(_) => max_degree, + }; + + // ---- Deserialize PQ data (always required — Java has validated this) ---- + + if pq_pivots.is_null() || pq_compressed.is_null() { + let _ = env.throw_new("java/lang/IllegalArgumentException", "PQ pivots and compressed data must not be null"); + return 0; + } + + let pivots_bytes: Vec = match env.convert_byte_array( + jni::objects::JByteArray::from(pq_pivots), + ) { + Ok(b) if !b.is_empty() => b, + _ => { let _ = env.throw_new("java/lang/IllegalArgumentException", "PQ pivots byte array is empty"); return 0; } + }; + let compressed_bytes: Vec = match env.convert_byte_array( + jni::objects::JByteArray::from(pq_compressed), + ) { + Ok(b) if !b.is_empty() => b, + _ => { let _ = env.throw_new("java/lang/IllegalArgumentException", "PQ compressed byte array is empty"); return 0; } + }; + + let pq_state = match paimon_fileio_provider::PQState::deserialize(&pivots_bytes, &compressed_bytes) { + Ok(pq) => pq, + Err(e) => { + let _ = env.throw_new("java/lang/RuntimeException", format!("PQ deserialization failed: {}", e)); + return 0; + } + }; + + // ---- Create beam search I/O context (separate GlobalRefs + JVM for search I/O) ---- + + let io_vector_ref = match env.new_global_ref(&vector_reader) { + Ok(g) => g, + Err(e) => { let _ = env.throw_new("java/lang/RuntimeException", format!("io vector ref: {}", e)); return 0; } + }; + let io_graph_ref = match env.new_global_ref(&graph_reader) { + Ok(g) => g, + Err(e) => { let _ = env.throw_new("java/lang/RuntimeException", format!("io graph ref: {}", e)); return 0; } + }; + let io_jvm = match env.get_java_vm() { + Ok(vm) => vm, + Err(e) => { let _ = env.throw_new("java/lang/RuntimeException", format!("io JVM: {}", e)); return 0; } + }; + let io_ctx = BeamSearchIOContext { + jvm: io_jvm, + vector_reader_ref: io_vector_ref, + graph_reader_ref: io_graph_ref, + single_buf_ptr, + dim, + metric_type, + }; + + // Start point is not stored in data file; use a dummy vector. + let start_vec = vec![1.0f32; dim]; + + // Build the FileIOProvider with on-demand graph reading and zero-copy vector access. + let provider = FileIOProvider::new_with_readers( + count, + start_id, + start_vec, + jvm, + global_vector_reader, + global_graph_reader, + dim, + metric_type, + max_degree, + single_buf_ptr, + batch_buf_ptr, + max_batch_size, + pq_state, + ); + + // Build DiskANNIndex config (still needed for the provider wrapper). + let md = std::cmp::max(max_degree, 4); + let bls = std::cmp::max(build_ls, md); + let metric = map_metric(metric_type); + + let index_config = match graph::config::Builder::new( + md, + graph::config::MaxDegree::same(), + bls, + metric.into(), + ).build() { + Ok(c) => c, + Err(e) => { let _ = env.throw_new("java/lang/RuntimeException", format!("config: {:?}", e)); return 0; } + }; + + let index = Arc::new(DiskANNIndex::new(index_config, provider, None)); + + let searcher = SearcherState { + index, + dimension, + min_ext_id, + start_id, + io_ctx, + }; + + match searcher_registry().lock() { + Ok(mut guard) => guard.insert(searcher), + Err(_) => { let _ = env.throw_new("java/lang/IllegalStateException", "Registry error"); 0 } + } +} + +// ======================== indexSearchWithReader ======================== + +/// Search on a searcher handle created by `indexCreateSearcherFromReaders`. 
+/// +/// Implements the standard DiskANN search algorithm: +/// +/// 1. **Start** from the medoid (graph entry point). +/// 2. **Beam Search loop**: +/// - Pop the unvisited node with the smallest **PQ distance** from the beam. +/// - **Disk I/O**: read its full vector + neighbor list via JNI. +/// - **Compute Exact**: use the full vector to compute exact distance, update +/// the result heap. +/// - **Expand**: for each neighbor, compute **PQ distance** (in-memory). +/// - **Push**: add neighbors with good PQ distance to the beam (capped at L). +/// 3. **Return** the result heap (already exactly sorted). +#[no_mangle] +pub extern "system" fn Java_org_apache_paimon_diskann_DiskAnnNative_indexSearchWithReader<'local>( + mut env: JNIEnv<'local>, + _class: JClass<'local>, + handle: jlong, + n: jlong, + query_vectors: JPrimitiveArray<'local, jfloat>, + k: jint, + search_list_size: jint, + distances: JPrimitiveArray<'local, jfloat>, + labels: JPrimitiveArray<'local, jlong>, +) { + let num = n as usize; + let top_k = k as usize; + + let arc = match get_searcher(handle) { + Some(a) => a, + None => { + let _ = env.throw_new("java/lang/IllegalArgumentException", "Invalid searcher handle"); + return; + } + }; + + // Copy query vectors. + let query: Vec = { + let elems = match unsafe { env.get_array_elements(&query_vectors, ReleaseMode::NoCopyBack) } { + Ok(a) => a, + Err(_) => { let _ = env.throw_new("java/lang/IllegalArgumentException", "Invalid queries"); return; } + }; + elems.iter().copied().collect() + }; + + let state = match arc.lock() { + Ok(s) => s, + Err(_) => { let _ = env.throw_new("java/lang/IllegalStateException", "Lock poisoned"); return; } + }; + + let dimension = state.dimension as usize; + let total = num * top_k; + let mut result_dist = vec![f32::MAX; total]; + let mut result_lbl = vec![-1i64; total]; + + let provider = state.index.provider(); + let pq = &provider.pq_state; + let io = &state.io_ctx; + let start_id = state.start_id; + + // Attach JNI thread once for all queries. + let mut jni_env = match io.jvm.attach_current_thread() { + Ok(e) => e, + Err(e) => { + let _ = env.throw_new("java/lang/RuntimeException", format!("JVM attach: {}", e)); + return; + } + }; + + for qi in 0..num { + let qvec = &query[qi * dimension..(qi + 1) * dimension]; + let l = std::cmp::max(search_list_size as usize, top_k); + + // ---- Pre-compute PQ distance lookup table for this query ---- + let distance_table = pq.compute_distance_table(qvec, io.metric_type); + + // ---- Beam: sorted candidate list, capped at L entries ---- + // Each entry: (pq_distance, internal_node_id). + // Sorted ascending by pq_distance so beam[0] is always the closest. + let mut beam: Vec<(f32, u32)> = Vec::with_capacity(l + 1); + let mut visited = std::collections::HashSet::::with_capacity(l * 2); + + // ---- Result heap: max-heap of (exact_distance, vec_idx) capped at top_k ---- + let mut results: Vec<(f32, usize)> = Vec::with_capacity(top_k + 1); + let mut result_worst = f32::MAX; + + // Seed beam with the start point (medoid). + // start_id is an internal node ID; vec_idx = start_id - 1. + { + let start_vec_idx = (start_id as usize).wrapping_sub(1); + let start_pq_dist = if start_vec_idx < pq.num_vectors { + pq.adc_distance(start_vec_idx, &distance_table, io.metric_type) + } else { + f32::MAX + }; + beam.push((start_pq_dist, start_id)); + } + + // ---- Beam search loop ---- + loop { + // Find the closest unvisited candidate in the beam. 
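+            // (Editorial note: this is a linear scan over at most L beam
+            // entries; with L typically in the low hundreds, the scan is
+            // assumed to be cheap next to the JNI/disk read performed for the
+            // chosen node.)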
+ let next = beam.iter() + .enumerate() + .filter(|(_, (_, id))| !visited.contains(id)) + .min_by(|a, b| a.1.0.partial_cmp(&b.1.0).unwrap_or(std::cmp::Ordering::Equal)) + .map(|(idx, &(dist, id))| (idx, dist, id)); + + let (_beam_idx, _pq_dist, node_id) = match next { + Some(t) => t, + None => break, // No more unvisited candidates — convergence. + }; + + visited.insert(node_id); + + // ---- Disk I/O: read full vector for this node ---- + // (skip start point if it has no data vector — int_id 0 is synthetic) + let vec_idx = (node_id as usize).wrapping_sub(1); + + if node_id != START_POINT_ID && vec_idx < pq.num_vectors { + let position = vec_idx as i64; + let load_ok = jni_env.call_method( + &io.vector_reader_ref, + "loadVector", + "(J)Z", + &[jni::objects::JValue::Long(position)], + ); + if let Ok(v) = load_ok { + if let Ok(true) = v.z() { + let full_vec = unsafe { + std::slice::from_raw_parts(io.single_buf_ptr, io.dim) + }; + let exact_dist = compute_exact_distance(qvec, full_vec, io.metric_type); + + // Update result heap (keep top_k smallest exact distances). + if results.len() < top_k { + results.push((exact_dist, vec_idx)); + if results.len() == top_k { + result_worst = results.iter() + .map(|e| e.0) + .fold(f32::NEG_INFINITY, f32::max); + } + } else if exact_dist < result_worst { + // Replace the worst entry. + if let Some(pos) = results.iter().position(|e| e.0 == result_worst) { + results[pos] = (exact_dist, vec_idx); + result_worst = results.iter() + .map(|e| e.0) + .fold(f32::NEG_INFINITY, f32::max); + } + } + } + } + } + + // ---- Read neighbor list for this node ---- + let neighbors: Vec = { + // Try graph cache first. + if let Some(term) = provider.graph.get(&node_id) { + term.neighbors.iter().copied().collect() + } else { + // Fetch from graph reader via JNI. + let ret = jni_env.call_method( + &io.graph_reader_ref, + "readNeighbors", + "(I)[I", + &[jni::objects::JValue::Int(node_id as i32)], + ); + match ret { + Ok(v) => match v.l() { + Ok(obj) if !obj.is_null() => { + let int_array = jni::objects::JIntArray::from(obj); + let arr_len = jni_env.get_array_length(&int_array).unwrap_or(0) as usize; + let mut buf = vec![0i32; arr_len]; + let _ = jni_env.get_int_array_region(&int_array, 0, &mut buf); + let nbrs: Vec = buf.into_iter().map(|v| v as u32).collect(); + // Cache for future queries. + let adj = diskann::graph::AdjacencyList::from_iter_untrusted(nbrs.iter().copied()); + provider.graph.insert(node_id, paimon_fileio_provider::GraphTerm { neighbors: adj }); + nbrs + } + _ => Vec::new(), + }, + Err(_) => { + let _ = jni_env.exception_clear(); + Vec::new() + } + } + } + }; + + // ---- Expand: compute PQ distance for each neighbor, add to beam ---- + let beam_worst = if beam.len() >= l { + beam.last().map(|e| e.0).unwrap_or(f32::MAX) + } else { + f32::MAX + }; + + for &nbr_id in &neighbors { + if visited.contains(&nbr_id) { + continue; + } + // Already in beam? Skip duplicate. + if beam.iter().any(|&(_, id)| id == nbr_id) { + continue; + } + + let nbr_vec_idx = (nbr_id as usize).wrapping_sub(1); + let nbr_pq_dist = if nbr_vec_idx < pq.num_vectors { + pq.adc_distance(nbr_vec_idx, &distance_table, io.metric_type) + } else { + f32::MAX + }; + + if beam.len() < l || nbr_pq_dist < beam_worst { + // Insert in sorted order. + let insert_pos = beam.partition_point(|e| e.0 < nbr_pq_dist); + beam.insert(insert_pos, (nbr_pq_dist, nbr_id)); + // Trim to L entries. 
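+                    // (Editorial note: L = max(search_list_size, top_k), the
+                    // DiskANN search-list bound; capping the beam at L keeps
+                    // per-query memory and the per-iteration scan at O(L).)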
+ if beam.len() > l { + beam.truncate(l); + } + } + } + } + + // ---- Collect top-K results (sorted by exact distance) ---- + results.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal)); + for (cnt, &(exact_dist, vec_idx)) in results.iter().enumerate() { + if cnt >= top_k { + break; + } + let idx = qi * top_k + cnt; + result_lbl[idx] = state.min_ext_id + vec_idx as i64; + result_dist[idx] = exact_dist; + } + } + + drop(jni_env); + drop(state); + + // Write back distances. + { + let mut de = match unsafe { env.get_array_elements(&distances, ReleaseMode::CopyBack) } { + Ok(a) => a, + Err(_) => { let _ = env.throw_new("java/lang/IllegalArgumentException", "Bad distances"); return; } + }; + for i in 0..std::cmp::min(de.len(), result_dist.len()) { de[i] = result_dist[i]; } + } + // Write back labels. + { + let mut le = match unsafe { env.get_array_elements(&labels, ReleaseMode::CopyBack) } { + Ok(a) => a, + Err(_) => { let _ = env.throw_new("java/lang/IllegalArgumentException", "Bad labels"); return; } + }; + for i in 0..std::cmp::min(le.len(), result_lbl.len()) { le[i] = result_lbl[i]; } + } +} + +/// Destroy a searcher handle. +#[no_mangle] +pub extern "system" fn Java_org_apache_paimon_diskann_DiskAnnNative_indexDestroySearcher<'local>( + mut env: JNIEnv<'local>, + _class: JClass<'local>, + handle: jlong, +) { + jni_catch_unwind(&mut env, (), AssertUnwindSafe(|| { + if let Ok(mut guard) = searcher_registry().lock() { + guard.searchers.remove(&handle); + } + })); +} + +// ======================== PQ Train & Encode ======================== + +/// Train a PQ codebook on the vectors stored in the index and encode all vectors. +/// +/// Uses `diskann-quantization`'s `LightPQTrainingParameters` for K-Means++ / Lloyd +/// training and `BasicTable` for encoding. +/// +/// Returns a `byte[][]` where: +/// `[0]` = serialized pivots (codebook) +/// `[1]` = serialized compressed codes +#[no_mangle] +pub extern "system" fn Java_org_apache_paimon_diskann_DiskAnnNative_pqTrainAndEncode<'local>( + mut env: JNIEnv<'local>, + _class: JClass<'local>, + handle: jlong, + num_subspaces: jint, + max_samples: jint, + kmeans_iters: jint, +) -> JObject<'local> { + // Obtain the index state outside of catch_unwind so we can throw typed exceptions. + let arc = match get_index(handle) { + Some(a) => a, + None => { + let _ = env.throw_new("java/lang/IllegalArgumentException", "Invalid index handle"); + return JObject::null(); + } + }; + let state = match arc.lock() { + Ok(s) => s, + Err(_) => { + let _ = env.throw_new("java/lang/IllegalStateException", "Index lock poisoned"); + return JObject::null(); + } + }; + + let dim = state.dimension as usize; + let m = num_subspaces as usize; + let max_s = max_samples as usize; + let iters = kmeans_iters as usize; + + // Perform PQ training and encoding inside catch_unwind to prevent panics crossing FFI. + let pq_result = { + let raw_data = &state.raw_data; + let result = panic::catch_unwind(AssertUnwindSafe(|| { + pq::train_and_encode(raw_data, dim, m, max_s, iters) + })); + // Drop the lock before JNI object creation. 
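+        // (Editorial note: the PQ result is owned at this point and no longer
+        // borrows `state`, so the index lock can be released before the JNI
+        // allocations below.)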
+ drop(state); + match result { + Ok(Ok(r)) => r, + Ok(Err(msg)) => { + let _ = env.throw_new("java/lang/RuntimeException", msg); + return JObject::null(); + } + Err(payload) => { + let msg = if let Some(s) = payload.downcast_ref::<&str>() { + s.to_string() + } else if let Some(s) = payload.downcast_ref::() { + s.clone() + } else { + "Unknown Rust panic in PQ training".to_string() + }; + let _ = env.throw_new("java/lang/RuntimeException", msg); + return JObject::null(); + } + } + }; + + // Build Java byte[][] result. + let pivots_array = match env.byte_array_from_slice(&pq_result.pivots_bytes) { + Ok(a) => a, + Err(e) => { + let _ = env.throw_new( + "java/lang/RuntimeException", + format!("Failed to create pivots byte[]: {}", e), + ); + return JObject::null(); + } + }; + let compressed_array = match env.byte_array_from_slice(&pq_result.compressed_bytes) { + Ok(a) => a, + Err(e) => { + let _ = env.throw_new( + "java/lang/RuntimeException", + format!("Failed to create compressed byte[]: {}", e), + ); + return JObject::null(); + } + }; + + let byte_array_class = match env.find_class("[B") { + Ok(c) => c, + Err(e) => { + let _ = env.throw_new( + "java/lang/RuntimeException", + format!("Failed to find [B class: {}", e), + ); + return JObject::null(); + } + }; + + let result = match env.new_object_array(2, &byte_array_class, &JObject::null()) { + Ok(a) => a, + Err(e) => { + let _ = env.throw_new( + "java/lang/RuntimeException", + format!("Failed to create byte[][]: {}", e), + ); + return JObject::null(); + } + }; + + if let Err(e) = env.set_object_array_element(&result, 0, &pivots_array) { + let _ = env.throw_new( + "java/lang/RuntimeException", + format!("Failed to set pivots: {}", e), + ); + return JObject::null(); + } + if let Err(e) = env.set_object_array_element(&result, 1, &compressed_array) { + let _ = env.throw_new( + "java/lang/RuntimeException", + format!("Failed to set compressed: {}", e), + ); + return JObject::null(); + } + + result.into() +} diff --git a/paimon-diskann/paimon-diskann-jni/src/main/native/src/paimon_fileio_provider.rs b/paimon-diskann/paimon-diskann-jni/src/main/native/src/paimon_fileio_provider.rs new file mode 100644 index 000000000000..149a9d83ae69 --- /dev/null +++ b/paimon-diskann/paimon-diskann-jni/src/main/native/src/paimon_fileio_provider.rs @@ -0,0 +1,1042 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +//! A DiskANN [`DataProvider`] backed by Paimon FileIO (local, HDFS, S3, OSS, etc.). +//! +//! Both the navigational graph (adjacency lists) and full-precision vectors +//! are stored in Paimon FileIO-backed storage and read on demand via JNI +//! callbacks to Java reader objects: +//! +//! - **Graph**: read through `FileIOGraphReader.readNeighbors(int)`, which +//! 
reads from a `SeekableInputStream` over the `.index` file. +//! - **Vectors**: read through `FileIOVectorReader.loadVector(long)` (zero-copy +//! via DirectByteBuffer) or `readVectorsBatch(long[], int)` (batch prefetch). +//! +//! Performance optimizations: +//! +//! - **Zero-copy vector reads**: `loadVector` writes into a pre-allocated +//! DirectByteBuffer. The Rust side reads floats directly from the native +//! memory address — no `float[]` allocation, no JNI array copy. +//! - **Batch prefetch**: When a node's neighbors are fetched, all neighbor +//! vectors are batch-prefetched in a single JNI call, populating the +//! provider-level `vector_cache`. Subsequent `get_element` calls hit cache. +//! - **Graph cache**: A `DashMap` lazily caches graph entries to reduce +//! repeated FileIO reads. + +use std::collections::HashMap; + +use dashmap::DashMap; +use diskann::graph::glue; +use diskann::graph::AdjacencyList; +use diskann::provider; +use diskann::{ANNError, ANNResult}; +use diskann_vector::distance::Metric; + +use jni::objects::GlobalRef; +use jni::JavaVM; + +use crate::map_metric; + +// ======================== Error ======================== + +#[derive(Debug, Clone)] +pub enum FileIOProviderError { + InvalidId(u32), + JniCallFailed(String), +} + +impl std::fmt::Display for FileIOProviderError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::InvalidId(id) => write!(f, "invalid vector id {}", id), + Self::JniCallFailed(msg) => write!(f, "FileIO read failed: {}", msg), + } + } +} + +impl std::error::Error for FileIOProviderError {} + +impl From for ANNError { + #[track_caller] + fn from(e: FileIOProviderError) -> ANNError { + ANNError::opaque(e) + } +} + +diskann::always_escalate!(FileIOProviderError); + +// ======================== PQ State ======================== + +/// In-memory Product Quantization state for approximate distance computation. +/// +/// During beam search, PQ-reconstructed vectors replace full-precision disk I/O, +/// making the search almost entirely in-memory. Only the final top-K candidates +/// are re-ranked with full-precision vectors from disk. +#[derive(Debug)] +pub struct PQState { + /// Number of PQ subspaces (M). + pub num_subspaces: usize, + /// Number of centroids per subspace (K). + pub num_centroids: usize, + /// Sub-vector dimension (dimension / M). + pub sub_dim: usize, + /// Full vector dimension. + pub dimension: usize, + /// Centroid data, laid out as: pivots[m * K * sub_dim + k * sub_dim .. + sub_dim]. + pub pivots: Vec, + /// Compressed codes: codes[vec_idx * M + m] = centroid index for vector vec_idx, subspace m. + pub codes: Vec, + /// Number of encoded vectors. + pub num_vectors: usize, +} + +impl PQState { + /// Deserialize PQ pivots and compressed codes from the byte arrays written by `pq.rs`. 
+ /// + /// Pivots format: i32 dim | i32 M | i32 K | i32 sub_dim | f32[M*K*sub_dim] + /// Codes format: i32 N | i32 M | byte[N*M] + pub fn deserialize(pivots_bytes: &[u8], compressed_bytes: &[u8]) -> Result { + if pivots_bytes.len() < 16 { + return Err("PQ pivots too small".into()); + } + if compressed_bytes.len() < 8 { + return Err("PQ compressed too small".into()); + } + + let dimension = i32::from_ne_bytes(pivots_bytes[0..4].try_into().unwrap()) as usize; + let num_subspaces = i32::from_ne_bytes(pivots_bytes[4..8].try_into().unwrap()) as usize; + let num_centroids = i32::from_ne_bytes(pivots_bytes[8..12].try_into().unwrap()) as usize; + let sub_dim = i32::from_ne_bytes(pivots_bytes[12..16].try_into().unwrap()) as usize; + + let expected_pivots_data = num_subspaces * num_centroids * sub_dim * 4; + if pivots_bytes.len() < 16 + expected_pivots_data { + return Err(format!( + "PQ pivots data too small: need {}, have {}", + 16 + expected_pivots_data, + pivots_bytes.len() + )); + } + + // Parse pivots as f32 (native endian). + let pivots: Vec = pivots_bytes[16..16 + expected_pivots_data] + .chunks_exact(4) + .map(|c| f32::from_ne_bytes(c.try_into().unwrap())) + .collect(); + + // Parse compressed header. + let num_vectors = i32::from_ne_bytes(compressed_bytes[0..4].try_into().unwrap()) as usize; + let m_check = i32::from_ne_bytes(compressed_bytes[4..8].try_into().unwrap()) as usize; + if m_check != num_subspaces { + return Err(format!( + "PQ subspace mismatch: pivots M={}, compressed M={}", + num_subspaces, m_check + )); + } + + let expected_codes = num_vectors * num_subspaces; + if compressed_bytes.len() < 8 + expected_codes { + return Err(format!( + "PQ codes too small: need {}, have {}", + 8 + expected_codes, + compressed_bytes.len() + )); + } + + let codes = compressed_bytes[8..8 + expected_codes].to_vec(); + + Ok(Self { + num_subspaces, + num_centroids, + sub_dim, + dimension, + pivots, + codes, + num_vectors, + }) + } + + /// Reconstruct an approximate vector for the given 0-based vector index + /// by looking up PQ centroid sub-vectors. + /// + /// Cost: M table lookups + dimension float copies — entirely in L1/L2 cache. + #[inline] + pub fn reconstruct(&self, vec_idx: usize, out: &mut [f32]) { + debug_assert!(vec_idx < self.num_vectors); + debug_assert!(out.len() >= self.dimension); + let code_base = vec_idx * self.num_subspaces; + for m in 0..self.num_subspaces { + let code = self.codes[code_base + m] as usize; + let src_offset = m * self.num_centroids * self.sub_dim + code * self.sub_dim; + let dst_offset = m * self.sub_dim; + out[dst_offset..dst_offset + self.sub_dim] + .copy_from_slice(&self.pivots[src_offset..src_offset + self.sub_dim]); + } + } + + // ---- ADC (Asymmetric Distance Computation) for brute-force PQ search ---- + + /// Pre-compute a distance table from a query vector to all PQ centroids. + /// + /// Returns `dt[m * K + k]` where: + /// - L2: squared L2 distance between query sub-vector m and centroid (m, k) + /// - IP/Cosine: dot product between query sub-vector m and centroid (m, k) + /// + /// Cost: O(M * K * sub_dim) — computed once per query, amortized over all vectors. 
+ pub fn compute_distance_table(&self, query: &[f32], metric_type: i32) -> Vec { + let m = self.num_subspaces; + let k = self.num_centroids; + let sd = self.sub_dim; + let mut table = vec![0.0f32; m * k]; + for mi in 0..m { + let q_start = mi * sd; + let q_sub = &query[q_start..q_start + sd]; + for ki in 0..k { + let c_off = mi * k * sd + ki * sd; + let centroid = &self.pivots[c_off..c_off + sd]; + let val = if metric_type == 1 || metric_type == 2 { + // IP or Cosine: dot product per subspace. + q_sub.iter().zip(centroid).map(|(a, b)| a * b).sum::() + } else { + // L2: squared L2 per subspace. + q_sub + .iter() + .zip(centroid) + .map(|(a, b)| { + let d = a - b; + d * d + }) + .sum::() + }; + table[mi * k + ki] = val; + } + } + table + } + + /// Compute the approximate PQ distance for one vector using a pre-computed + /// distance table. Cost: O(M) — just M table lookups. + #[inline] + pub fn adc_distance(&self, vec_idx: usize, table: &[f32], metric_type: i32) -> f32 { + let base = vec_idx * self.num_subspaces; + let k = self.num_centroids; + let mut raw = 0.0f32; + for mi in 0..self.num_subspaces { + let code = self.codes[base + mi] as usize; + raw += table[mi * k + code]; + } + // IP/Cosine: negate dot product so that larger similarity = smaller distance. + if metric_type == 1 || metric_type == 2 { + -raw + } else { + raw + } + } + + /// Brute-force PQ search: scan all vectors with ADC, return the top-K + /// closest as `(vec_idx, pq_distance)` sorted by ascending distance. + /// + /// This replaces the graph-based beam search when PQ data is available. + /// The entire search is in-memory — zero graph I/O, zero vector disk I/O. + /// + /// Cost: O(M * K * sub_dim) for the distance table + + /// O(N * M) for scanning all vectors. + pub fn brute_force_search( + &self, + query: &[f32], + top_k: usize, + metric_type: i32, + ) -> Vec<(usize, f32)> { + let table = self.compute_distance_table(query, metric_type); + let k = top_k.min(self.num_vectors); + if k == 0 { + return Vec::new(); + } + + // Max-heap to keep the top-k smallest distances. + // We store (distance, vec_idx) and the heap evicts the largest distance. + let mut heap: Vec<(f32, usize)> = Vec::with_capacity(k); + let mut heap_max = f32::MAX; + + for i in 0..self.num_vectors { + let dist = self.adc_distance(i, &table, metric_type); + if heap.len() < k { + heap.push((dist, i)); + if heap.len() == k { + // Find the current maximum after filling the heap. + heap_max = heap.iter().map(|e| e.0).fold(f32::NEG_INFINITY, f32::max); + } + } else if dist < heap_max { + // Replace the worst element. + if let Some(pos) = heap.iter().position(|e| e.0 == heap_max) { + heap[pos] = (dist, i); + heap_max = heap.iter().map(|e| e.0).fold(f32::NEG_INFINITY, f32::max); + } + } + } + + // Sort by ascending distance. + heap.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal)); + heap.into_iter().map(|(d, i)| (i, d)).collect() + } +} + +// ======================== Graph Term ======================== + +/// One entry in the graph cache: its neighbor list. +pub struct GraphTerm { + pub neighbors: AdjacencyList, +} + +// ======================== FileIOProvider ======================== + +/// DiskANN data provider backed by Paimon FileIO. +/// +/// Graph neighbors and vectors are read on demand from FileIO-backed storage +/// (local, HDFS, S3, OSS, etc.) via JNI callbacks to Java reader objects. +/// +/// Three levels of vector access (in priority order): +/// 1. 
**PQ reconstruction** (in-memory, ~O(dim) CPU, no I/O) — used during +/// beam search when PQ data is available. +/// 2. **Provider-level `vector_cache`** (exact vectors cached from reranking +/// or disk I/O). +/// 3. **DirectByteBuffer disk I/O** (single JNI call + zero-copy read). +/// +/// Graph neighbors are cached in a `DashMap` (lazy, write-once). +pub struct FileIOProvider { + /// Graph cache: internal_id → { neighbors }. + pub graph: DashMap, + /// Provider-level vector cache (exact vectors from reranking / disk reads). + vector_cache: DashMap>, + /// Total number of nodes (start point + user vectors). + num_nodes: usize, + /// Start-point IDs and their vectors (always kept in memory). + start_points: HashMap>, + /// JVM handle for attaching threads. + jvm: JavaVM, + /// Global reference to the Java vector reader object (`FileIOVectorReader`). + reader_ref: GlobalRef, + /// Global reference to the Java graph reader object (`FileIOGraphReader`). + graph_reader_ref: Option, + /// Vector dimension. + dim: usize, + /// Distance metric. + metric: Metric, + /// Max degree. + max_degree: usize, + /// Native memory address of the single-vector DirectByteBuffer. + single_buf_ptr: *mut f32, + /// Native memory address of the batch DirectByteBuffer. + batch_buf_ptr: *mut f32, + /// Maximum number of vectors in one batch read. + max_batch_size: usize, + /// PQ state for in-memory approximate distance computation during beam search. + /// PQ is always present — Java validates this before creating the native searcher. + pub pq_state: PQState, +} + +impl std::fmt::Debug for FileIOProvider { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("FileIOProvider") + .field("dim", &self.dim) + .field("metric", &self.metric) + .field("max_degree", &self.max_degree) + .field("graph_size", &self.graph.len()) + .field("vector_cache_size", &self.vector_cache.len()) + .field("pq_enabled", &true) + .finish() + } +} + +// SAFETY: JavaVM is Send+Sync, GlobalRef is Send+Sync. +// Raw pointers are stable (backed by Java DirectByteBuffer kept alive by GlobalRef). +// All access is serialized by single-threaded tokio runtime. +unsafe impl Send for FileIOProvider {} +unsafe impl Sync for FileIOProvider {} + +impl FileIOProvider { + /// Build a search-only provider with on-demand graph reading, zero-copy + /// vector access, and PQ for in-memory approximate search. + /// + /// `single_buf_ptr` and `batch_buf_ptr` are native addresses obtained via + /// JNI `GetDirectBufferAddress` on the Java reader's DirectByteBuffers. + /// PQ is always required — Java validates this before creating the native searcher. + #[allow(clippy::too_many_arguments)] + pub fn new_with_readers( + num_nodes: usize, + start_id: u32, + start_vec: Vec, + jvm: JavaVM, + reader_ref: GlobalRef, + graph_reader_ref: GlobalRef, + dim: usize, + metric_type: i32, + max_degree: usize, + single_buf_ptr: *mut f32, + batch_buf_ptr: *mut f32, + max_batch_size: usize, + pq_state: PQState, + ) -> Self { + let graph = DashMap::new(); + let vector_cache = DashMap::new(); + + let mut start_points = HashMap::new(); + start_points.insert(start_id, start_vec); + + Self { + graph, + vector_cache, + num_nodes, + start_points, + jvm, + reader_ref, + graph_reader_ref: Some(graph_reader_ref), + dim, + metric: map_metric(metric_type), + max_degree, + single_buf_ptr, + batch_buf_ptr, + max_batch_size, + pq_state, + } + } + + /// PQ brute-force search: scan all PQ codes, return top-K candidates. 
+ /// + /// Returns `vec of (vec_idx, pq_distance)`. + /// `vec_idx` is 0-based (data file position); `int_id = vec_idx + 1`. + pub fn pq_brute_force_search( + &self, + query: &[f32], + top_k: usize, + ) -> Vec<(usize, f32)> { + let metric_type = match self.metric { + Metric::InnerProduct => 1i32, + Metric::Cosine => 2i32, + _ => 0i32, + }; + self.pq_state.brute_force_search(query, top_k, metric_type) + } + + pub fn dim(&self) -> usize { + self.dim + } + + pub fn metric(&self) -> Metric { + self.metric + } + + // ---- Graph I/O ---- + + /// Fetch neighbor list from FileIO-backed storage via JNI callback to + /// `graphReader.readNeighbors(int)`. + fn fetch_neighbors(&self, int_id: u32) -> Result>, FileIOProviderError> { + let graph_ref = match &self.graph_reader_ref { + Some(r) => r, + None => return Ok(None), + }; + + let mut env = self + .jvm + .attach_current_thread() + .map_err(|e| FileIOProviderError::JniCallFailed(format!("attach failed: {}", e)))?; + + let result = env.call_method( + graph_ref, + "readNeighbors", + "(I)[I", + &[jni::objects::JValue::Int(int_id as i32)], + ); + + let ret_val = match result { + Ok(v) => v, + Err(e) => { + let _ = env.exception_clear(); + return Err(FileIOProviderError::JniCallFailed(format!( + "readNeighbors({}) failed: {}", + int_id, e + ))); + } + }; + + let obj = match ret_val.l() { + Ok(o) => o, + Err(_) => return Ok(Some(Vec::new())), + }; + + if obj.is_null() { + return Ok(Some(Vec::new())); + } + + let int_array = jni::objects::JIntArray::from(obj); + let len = env + .get_array_length(&int_array) + .map_err(|e| FileIOProviderError::JniCallFailed(format!("get_array_length: {}", e)))? + as usize; + + let mut buf = vec![0i32; len]; + env.get_int_array_region(&int_array, 0, &mut buf) + .map_err(|e| FileIOProviderError::JniCallFailed(format!("get_int_array_region: {}", e)))?; + + Ok(Some(buf.into_iter().map(|v| v as u32).collect())) + } + + // ---- Vector I/O (zero-copy via DirectByteBuffer) ---- + + /// Fetch a single vector via `loadVector(long)` and read from DirectByteBuffer. + /// + /// The Java method writes the vector into the pre-allocated DirectByteBuffer. + /// We then read floats directly from the native address — no `float[]` + /// allocation and no JNI array copy. + fn fetch_vector(&self, position: i64) -> Result>, FileIOProviderError> { + let mut env = self + .jvm + .attach_current_thread() + .map_err(|e| FileIOProviderError::JniCallFailed(format!("attach failed: {}", e)))?; + + let result = env.call_method( + &self.reader_ref, + "loadVector", + "(J)Z", + &[jni::objects::JValue::Long(position)], + ); + + let success = match result { + Ok(v) => match v.z() { + Ok(b) => b, + Err(_) => false, + }, + Err(e) => { + let _ = env.exception_clear(); + return Err(FileIOProviderError::JniCallFailed(format!( + "loadVector({}) failed: {}", + position, e + ))); + } + }; + + if !success { + return Ok(None); + } + + // Read floats directly from the DirectByteBuffer native address. + // SAFETY: single_buf_ptr is valid (backed by Java DirectByteBuffer kept alive + // by GlobalRef), and access is serialized (single-threaded tokio runtime). + let vec = unsafe { + let slice = std::slice::from_raw_parts(self.single_buf_ptr, self.dim); + slice.to_vec() + }; + + Ok(Some(vec)) + } + + /// Batch-prefetch vectors into the provider-level `vector_cache`. + /// + /// Calls `readVectorsBatch(long[], int)` once via JNI, then reads all vectors + /// from the batch DirectByteBuffer native address. 
Each vector is inserted + /// into `vector_cache` keyed by its internal node ID. + /// + /// `ids` contains internal node IDs (not positions). Position = id − 1. + fn prefetch_vectors(&self, ids: &[u32]) -> Result<(), FileIOProviderError> { + if ids.is_empty() || self.batch_buf_ptr.is_null() { + return Ok(()); + } + + let count = std::cmp::min(ids.len(), self.max_batch_size); + + let mut env = self + .jvm + .attach_current_thread() + .map_err(|e| FileIOProviderError::JniCallFailed(format!("attach failed: {}", e)))?; + + // Build Java long[] of positions (position = int_id − 1). + let positions: Vec = ids[..count].iter().map(|&id| (id as i64) - 1).collect(); + let java_positions = env + .new_long_array(count as i32) + .map_err(|e| FileIOProviderError::JniCallFailed(format!("new_long_array: {}", e)))?; + env.set_long_array_region(&java_positions, 0, &positions) + .map_err(|e| FileIOProviderError::JniCallFailed(format!("set_long_array_region: {}", e)))?; + + // Single JNI call: readVectorsBatch(long[], int) → int + // SAFETY: JLongArray wraps a JObject; we reinterpret the raw pointer. + let positions_obj = unsafe { + jni::objects::JObject::from_raw(java_positions.as_raw()) + }; + let result = env.call_method( + &self.reader_ref, + "readVectorsBatch", + "([JI)I", + &[ + jni::objects::JValue::Object(&positions_obj), + jni::objects::JValue::Int(count as i32), + ], + ); + // Prevent double-free: positions_obj shares the raw handle with java_positions. + std::mem::forget(positions_obj); + + let read_count = match result { + Ok(v) => match v.i() { + Ok(n) => n as usize, + Err(_) => 0, + }, + Err(e) => { + let _ = env.exception_clear(); + return Err(FileIOProviderError::JniCallFailed(format!( + "readVectorsBatch failed: {}", + e + ))); + } + }; + + // Read vectors from batch DirectByteBuffer native address and populate cache. + // SAFETY: batch_buf_ptr is valid, access is serialized. + for i in 0..read_count { + let int_id = ids[i]; + if self.vector_cache.contains_key(&int_id) { + continue; // already cached + } + let offset = i * self.dim; + let vec = unsafe { + let slice = std::slice::from_raw_parts(self.batch_buf_ptr.add(offset), self.dim); + slice.to_vec() + }; + self.vector_cache.insert(int_id, vec.into_boxed_slice()); + } + + Ok(()) + } +} + +// ======================== Context ======================== + +/// Lightweight execution context for the FileIO provider. 
+#[derive(Debug, Clone, Default)] +pub struct FileIOContext; + +impl std::fmt::Display for FileIOContext { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "paimon fileio context") + } +} + +impl provider::ExecutionContext for FileIOContext {} + +// ======================== DataProvider ======================== + +impl provider::DataProvider for FileIOProvider { + type Context = FileIOContext; + type InternalId = u32; + type ExternalId = u32; + type Error = FileIOProviderError; + + fn to_internal_id( + &self, + _context: &FileIOContext, + gid: &u32, + ) -> Result { + if (*gid as usize) < self.num_nodes { + Ok(*gid) + } else { + Err(FileIOProviderError::InvalidId(*gid)) + } + } + + fn to_external_id( + &self, + _context: &FileIOContext, + id: u32, + ) -> Result { + if (id as usize) < self.num_nodes { + Ok(id) + } else { + Err(FileIOProviderError::InvalidId(id)) + } + } +} + +// ======================== SetElement (stub — search only) ======================== + +impl provider::SetElement<[f32]> for FileIOProvider { + type SetError = ANNError; + type Guard = provider::NoopGuard; + + async fn set_element( + &self, + _context: &FileIOContext, + _id: &u32, + _element: &[f32], + ) -> Result { + Err(ANNError::opaque(FileIOProviderError::JniCallFailed( + "set_element not supported on search-only FileIOProvider".to_string(), + ))) + } +} + +// ======================== NeighborAccessor ======================== + +#[derive(Debug, Clone, Copy)] +pub struct FileIONeighborAccessor<'a> { + provider: &'a FileIOProvider, +} + +impl provider::HasId for FileIONeighborAccessor<'_> { + type Id = u32; +} + +impl provider::NeighborAccessor for FileIONeighborAccessor<'_> { + async fn get_neighbors( + self, + id: Self::Id, + neighbors: &mut AdjacencyList, + ) -> ANNResult { + // 1. Try cached graph. + if let Some(term) = self.provider.graph.get(&id) { + neighbors.overwrite_trusted(&term.neighbors); + // Batch-prefetch neighbor vectors that aren't cached yet. + self.prefetch_neighbor_vectors(&term.neighbors); + return Ok(self); + } + + // 2. On-demand: fetch from FileIO-backed storage via graph reader JNI callback. + if self.provider.graph_reader_ref.is_some() { + let fetched = self.provider.fetch_neighbors(id)?; + if let Some(neighbor_ids) = fetched { + let adj = AdjacencyList::from_iter_untrusted(neighbor_ids.iter().copied()); + neighbors.overwrite_trusted(&adj); + // Cache in the DashMap for subsequent accesses. + self.provider.graph.insert(id, GraphTerm { neighbors: adj.clone() }); + // Batch-prefetch neighbor vectors. + self.prefetch_neighbor_vectors(&adj); + return Ok(self); + } + } + + Err(ANNError::opaque(FileIOProviderError::InvalidId(id))) + } +} + +impl FileIONeighborAccessor<'_> { + /// Batch-prefetch vectors for neighbors that aren't already in the vector cache. + /// + /// PQ is always enabled — beam search uses PQ reconstruction (in-memory), + /// so disk prefetch is a no-op. This method is kept for the graph::Strategy + /// trait requirement. + fn prefetch_neighbor_vectors(&self, _adj: &AdjacencyList) { + // No-op: PQ is always enabled, beam search uses in-memory PQ codes. 
+ } +} + +impl provider::NeighborAccessorMut for FileIONeighborAccessor<'_> { + async fn set_neighbors(self, id: Self::Id, neighbors: &[Self::Id]) -> ANNResult { + match self.provider.graph.get_mut(&id) { + Some(mut term) => { + term.neighbors.clear(); + term.neighbors.extend_from_slice(neighbors); + Ok(self) + } + None => Err(ANNError::opaque(FileIOProviderError::InvalidId(id))), + } + } + + async fn append_vector(self, id: Self::Id, neighbors: &[Self::Id]) -> ANNResult { + match self.provider.graph.get_mut(&id) { + Some(mut term) => { + term.neighbors.extend_from_slice(neighbors); + Ok(self) + } + None => Err(ANNError::opaque(FileIOProviderError::InvalidId(id))), + } + } +} + +// ======================== DefaultAccessor ======================== + +impl provider::DefaultAccessor for FileIOProvider { + type Accessor<'a> = FileIONeighborAccessor<'a>; + + fn default_accessor(&self) -> Self::Accessor<'_> { + FileIONeighborAccessor { provider: self } + } +} + +// ======================== Accessor ======================== + +/// Accessor that fetches vectors with three cache levels: +/// +/// 1. **Start-point** (always in memory) +/// 2. **Provider-level `vector_cache`** (populated by batch prefetch) +/// 3. **Per-search LRU cache** (local to this accessor) +/// 4. **DirectByteBuffer I/O** (fallback: single JNI call + zero-copy read) +pub struct FileIOAccessor<'a> { + provider: &'a FileIOProvider, + buffer: Box<[f32]>, + cache: VectorCache, +} + +impl<'a> FileIOAccessor<'a> { + pub fn new(provider: &'a FileIOProvider, cache_size: usize) -> Self { + let buffer = vec![0.0f32; provider.dim()].into_boxed_slice(); + Self { + provider, + buffer, + cache: VectorCache::new(cache_size), + } + } +} + +impl provider::HasId for FileIOAccessor<'_> { + type Id = u32; +} + +impl provider::Accessor for FileIOAccessor<'_> { + type Extended = Box<[f32]>; + type Element<'a> = &'a [f32] where Self: 'a; + type ElementRef<'a> = &'a [f32]; + type GetError = FileIOProviderError; + + async fn get_element( + &mut self, + id: u32, + ) -> Result, Self::GetError> { + // 1. Start-point vectors are always in memory. + if let Some(vec) = self.provider.start_points.get(&id) { + self.buffer.copy_from_slice(vec); + return Ok(&*self.buffer); + } + + // 2. Provider-level vector cache (exact vectors from reranking / disk). + if let Some(cached) = self.provider.vector_cache.get(&id) { + self.buffer.copy_from_slice(&cached); + return Ok(&*self.buffer); + } + + // 3. Per-search LRU cache (exact vectors). + if let Some(cached) = self.cache.get(id) { + self.buffer.copy_from_slice(cached); + return Ok(&*self.buffer); + } + + // 4. PQ reconstruction: approximate vector, entirely in-memory, no disk I/O. + // During beam search this is the primary hot path. + { + let pq = &self.provider.pq_state; + let vec_idx = (id as usize).wrapping_sub(1); + if vec_idx < pq.num_vectors { + pq.reconstruct(vec_idx, &mut self.buffer); + return Ok(&*self.buffer); + } + } + + // 5. Fallback: fetch via DirectByteBuffer I/O (single JNI call + zero-copy read). 
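+        // (Editorial note: on the search-only path this fallback is reached
+        // only for an id that is not a start point, not in either cache, and
+        // has no PQ code, i.e. vec_idx >= num_vectors.)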
+ let position = (id as i64) - 1; + let fetched = self.provider.fetch_vector(position)?; + + match fetched { + Some(vec) if vec.len() == self.provider.dim() => { + self.buffer.copy_from_slice(&vec); + self.provider + .vector_cache + .insert(id, vec.clone().into_boxed_slice()); + self.cache.put(id, vec.into_boxed_slice()); + Ok(&*self.buffer) + } + Some(vec) => Err(FileIOProviderError::JniCallFailed(format!( + "loadVector({}) returned {} floats, expected {}", + position, + vec.len(), + self.provider.dim() + ))), + None => Err(FileIOProviderError::InvalidId(id)), + } + } +} + +// ======================== DelegateNeighbor ======================== + +impl<'this> provider::DelegateNeighbor<'this> for FileIOAccessor<'_> { + type Delegate = FileIONeighborAccessor<'this>; + + fn delegate_neighbor(&'this mut self) -> Self::Delegate { + FileIONeighborAccessor { + provider: self.provider, + } + } +} + +// ======================== BuildQueryComputer ======================== + +impl provider::BuildQueryComputer<[f32]> for FileIOAccessor<'_> { + type QueryComputerError = diskann::error::Infallible; + type QueryComputer = ::QueryDistance; + + fn build_query_computer( + &self, + from: &[f32], + ) -> Result { + Ok(::query_distance( + from, + self.provider.metric(), + )) + } +} + +// ======================== BuildDistanceComputer ======================== + +impl provider::BuildDistanceComputer for FileIOAccessor<'_> { + type DistanceComputerError = diskann::error::Infallible; + type DistanceComputer = ::Distance; + + fn build_distance_computer( + &self, + ) -> Result { + Ok(::distance( + self.provider.metric(), + Some(self.provider.dim()), + )) + } +} + +// ======================== SearchExt ======================== + +impl glue::SearchExt for FileIOAccessor<'_> { + fn starting_points( + &self, + ) -> impl std::future::Future>> + Send { + futures_util::future::ok(self.provider.start_points.keys().copied().collect()) + } +} + +// ======================== Blanket traits ======================== + +impl glue::ExpandBeam<[f32]> for FileIOAccessor<'_> {} +impl glue::FillSet for FileIOAccessor<'_> {} + +// ======================== Strategy ======================== + +/// Search-only strategy for the Paimon FileIO provider. +#[derive(Debug, Default, Clone, Copy)] +pub struct FileIOStrategy; + +impl FileIOStrategy { + pub fn new() -> Self { + Self + } +} + +impl glue::SearchStrategy for FileIOStrategy { + type QueryComputer = ::QueryDistance; + type PostProcessor = glue::CopyIds; + type SearchAccessorError = diskann::error::Infallible; + type SearchAccessor<'a> = FileIOAccessor<'a>; + + fn search_accessor<'a>( + &'a self, + provider: &'a FileIOProvider, + _context: &'a FileIOContext, + ) -> Result, diskann::error::Infallible> { + // Per-search LRU as last-resort cache (most hits come from vector_cache). + Ok(FileIOAccessor::new(provider, 1024)) + } + + fn post_processor(&self) -> Self::PostProcessor { + Default::default() + } +} + +// For insert (graph construction) — delegates to prune/search accessors. +// We implement InsertStrategy and PruneStrategy as stubs since the FileIOProvider +// is search-only. DiskANNIndex::new() requires the Provider to be Sized but +// does NOT call insert methods unless we invoke index.insert(). 
+impl glue::PruneStrategy for FileIOStrategy { + type DistanceComputer = ::Distance; + type PruneAccessor<'a> = FileIOAccessor<'a>; + type PruneAccessorError = diskann::error::Infallible; + + fn prune_accessor<'a>( + &'a self, + provider: &'a FileIOProvider, + _context: &'a FileIOContext, + ) -> Result, Self::PruneAccessorError> { + Ok(FileIOAccessor::new(provider, 1024)) + } +} + +impl glue::InsertStrategy for FileIOStrategy { + type PruneStrategy = Self; + + fn prune_strategy(&self) -> Self::PruneStrategy { + *self + } + + fn insert_search_accessor<'a>( + &'a self, + provider: &'a FileIOProvider, + _context: &'a FileIOContext, + ) -> Result, Self::SearchAccessorError> { + Ok(FileIOAccessor::new(provider, 1024)) + } +} + +impl<'a> glue::AsElement<&'a [f32]> for FileIOAccessor<'a> { + type Error = diskann::error::Infallible; + fn as_element( + &mut self, + vector: &'a [f32], + _id: Self::Id, + ) -> impl std::future::Future, Self::Error>> + Send { + std::future::ready(Ok(vector)) + } +} + +// ======================== VectorCache (per-search LRU) ======================== + +/// Tiny LRU cache for per-search vector access. Most hits should come from the +/// provider-level `vector_cache`; this is a last-resort fallback. +struct VectorCache { + map: HashMap>, + order: Vec, + capacity: usize, +} + +impl VectorCache { + fn new(capacity: usize) -> Self { + Self { + map: HashMap::with_capacity(capacity), + order: Vec::with_capacity(capacity), + capacity, + } + } + + fn get(&self, id: u32) -> Option<&[f32]> { + self.map.get(&id).map(|v| &**v) + } + + fn put(&mut self, id: u32, vec: Box<[f32]>) { + if self.map.contains_key(&id) { + return; + } + if self.order.len() >= self.capacity { + if let Some(evicted) = self.order.first().copied() { + self.order.remove(0); + self.map.remove(&evicted); + } + } + self.order.push(id); + self.map.insert(id, vec); + } +} diff --git a/paimon-diskann/paimon-diskann-jni/src/main/native/src/pq.rs b/paimon-diskann/paimon-diskann-jni/src/main/native/src/pq.rs new file mode 100644 index 000000000000..8dfb5cb7e3df --- /dev/null +++ b/paimon-diskann/paimon-diskann-jni/src/main/native/src/pq.rs @@ -0,0 +1,316 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +//! Product Quantization (PQ) for DiskANN, backed by `diskann-quantization`. +//! +//! PQ compresses high-dimensional vectors into compact codes by: +//! 1. Splitting each vector into `M` sub-vectors (subspaces). +//! 2. Training `K` centroids per subspace using K-Means clustering +//! (via `diskann-quantization`'s `LightPQTrainingParameters`). +//! 3. Encoding each sub-vector as the index of its nearest centroid (1 byte for K≤256). 
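+
+// Editorial sketch (hypothetical helper, not used by this module): nearest-
+// centroid encoding as described in steps 1-3 above. The real encoding path
+// goes through diskann-quantization's `BasicTable::compress_into` in
+// `train_and_encode` below. As a size example, a 128-dim f32 vector
+// (512 bytes) with M = 16 subspaces and K = 256 centroids compresses to a
+// 16-byte code.
+#[allow(dead_code)]
+fn encode_pq_sketch(vector: &[f32], centroids_per_subspace: &[Vec<Vec<f32>>]) -> Vec<u8> {
+    let m = centroids_per_subspace.len();
+    assert!(
+        m > 0 && vector.len() % m == 0,
+        "dimension must divide evenly into M subspaces"
+    );
+    let sub_dim = vector.len() / m;
+    (0..m)
+        .map(|mi| {
+            let sub = &vector[mi * sub_dim..(mi + 1) * sub_dim];
+            // The code for this subspace is the index of the centroid with the
+            // smallest squared L2 distance to the sub-vector.
+            centroids_per_subspace[mi]
+                .iter()
+                .enumerate()
+                .map(|(k, c)| {
+                    let d: f32 = sub.iter().zip(c).map(|(a, b)| (a - b) * (a - b)).sum();
+                    (k, d)
+                })
+                .min_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
+                .map(|(k, _)| k as u8)
+                .unwrap_or(0)
+        })
+        .collect()
+}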
+ +use diskann_quantization::cancel::DontCancel; +use diskann_quantization::product::train::{LightPQTrainingParameters, TrainQuantizer}; +use diskann_quantization::product::BasicTable; +use diskann_quantization::random::StdRngBuilder; +use diskann_quantization::views::ChunkOffsetsView; +use diskann_quantization::{CompressInto, Parallelism}; +use diskann_utils::views::Matrix; + +/// Maximum centroids per subspace. Fixed at 256 so each code fits in one byte (u8). +const NUM_CENTROIDS: usize = 256; + +/// Result of PQ training and encoding. +#[derive(Debug)] +pub struct PQResult { + /// Serialized PQ codebook (pivots). + pub pivots_bytes: Vec, + /// Serialized compressed PQ codes. + pub compressed_bytes: Vec, +} + +/// Train a PQ codebook and encode all vectors using `diskann-quantization`. +/// +/// * `vectors` — training vectors (stored sequentially, position = ID). +/// * `dimension` — vector dimension (must be divisible by `num_subspaces`). +/// * `num_subspaces` — number of PQ subspaces (M). +/// * `max_samples` — maximum number of vectors sampled for training. +/// * `kmeans_iters` — number of Lloyd iterations for K-Means. +/// +/// Returns serialized pivots and compressed codes. +pub fn train_and_encode( + vectors: &[Vec], + dimension: usize, + num_subspaces: usize, + max_samples: usize, + kmeans_iters: usize, +) -> Result { + if dimension == 0 || num_subspaces == 0 || dimension % num_subspaces != 0 { + return Err(format!( + "Invalid PQ params: dim={}, num_subspaces={}", + dimension, num_subspaces + )); + } + + let n = vectors.len(); + if n == 0 { + return Err("No vectors to train PQ".to_string()); + } + + let sub_dim = dimension / num_subspaces; + let k = std::cmp::min(NUM_CENTROIDS, n); + + // --- Build training data as a Matrix (nrows=num_samples, ncols=dim) --- + let sample_n = std::cmp::min(n, max_samples); + let sample_indices = if n > max_samples { + sample_indices_det(n, max_samples) + } else { + (0..n).collect() + }; + + let mut training_data = Matrix::new(0.0f32, sample_n, dimension); + for (dst_row, &src_idx) in sample_indices.iter().enumerate() { + training_data + .row_mut(dst_row) + .copy_from_slice(&vectors[src_idx]); + } + + // --- Build chunk offsets (uniform subspaces) --- + // e.g. 
+pub fn train_and_encode(
+    vectors: &[Vec<f32>],
+    dimension: usize,
+    num_subspaces: usize,
+    max_samples: usize,
+    kmeans_iters: usize,
+) -> Result<PQResult, String> {
+    if dimension == 0 || num_subspaces == 0 || dimension % num_subspaces != 0 {
+        return Err(format!(
+            "Invalid PQ params: dim={}, num_subspaces={}",
+            dimension, num_subspaces
+        ));
+    }
+
+    let n = vectors.len();
+    if n == 0 {
+        return Err("No vectors to train PQ".to_string());
+    }
+
+    let sub_dim = dimension / num_subspaces;
+    let k = std::cmp::min(NUM_CENTROIDS, n);
+
+    // --- Build training data as a Matrix (nrows=num_samples, ncols=dim) ---
+    let sample_n = std::cmp::min(n, max_samples);
+    let sample_indices = if n > max_samples {
+        sample_indices_det(n, max_samples)
+    } else {
+        (0..n).collect()
+    };
+
+    let mut training_data = Matrix::new(0.0f32, sample_n, dimension);
+    for (dst_row, &src_idx) in sample_indices.iter().enumerate() {
+        training_data
+            .row_mut(dst_row)
+            .copy_from_slice(&vectors[src_idx]);
+    }
+
+    // --- Build chunk offsets (uniform subspaces) ---
+    // e.g. for dim=8, M=2: offsets = [0, 4, 8]
+    let offsets: Vec<usize> = (0..=num_subspaces).map(|i| i * sub_dim).collect();
+    let schema = ChunkOffsetsView::new(&offsets)
+        .map_err(|e| format!("Failed to create PQ chunk offsets: {}", e))?;
+
+    // --- Train using diskann-quantization ---
+    let trainer = LightPQTrainingParameters::new(k, kmeans_iters);
+    let rng_builder = StdRngBuilder::new(42);
+
+    let quantizer = trainer
+        .train(
+            training_data.as_view(),
+            schema,
+            Parallelism::Sequential,
+            &rng_builder,
+            &DontCancel,
+        )
+        .map_err(|e| format!("PQ training failed: {}", e))?;
+
+    // --- Build BasicTable for encoding ---
+    let ncenters = quantizer.pivots()[0].nrows();
+    let flat_pivots: Vec<f32> = quantizer.flatten();
+    let pivots_matrix = Matrix::try_from(flat_pivots.into_boxed_slice(), ncenters, dimension)
+        .map_err(|e| format!("Failed to create pivot matrix: {}", e))?;
+
+    let offsets_owned = schema.to_owned();
+    let table = BasicTable::new(pivots_matrix, offsets_owned)
+        .map_err(|e| format!("Failed to create BasicTable: {}", e))?;
+
+    // --- Encode all vectors ---
+    let mut all_codes: Vec<Vec<u8>> = Vec::with_capacity(n);
+    for vec in vectors.iter() {
+        let mut code = vec![0u8; num_subspaces];
+        table
+            .compress_into(vec.as_slice(), &mut code)
+            .map_err(|e| format!("PQ compression failed: {}", e))?;
+        all_codes.push(code);
+    }
+
+    // --- Serialize ---
+    let pivots_bytes = serialize_pivots(&table, dimension, num_subspaces, ncenters, sub_dim);
+    let compressed_bytes = serialize_compressed(&all_codes, num_subspaces);
+
+    Ok(PQResult {
+        pivots_bytes,
+        compressed_bytes,
+    })
+}
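+
+// Size bookkeeping for the two buffers produced above (the same arithmetic is
+// asserted by the round-trip test at the bottom of this file):
+//
+//     pivots_bytes.len()     == 16 + num_subspaces * num_centroids * sub_dim * 4
+//     compressed_bytes.len() ==  8 + num_vectors * num_subspaces
+//
+// e.g. dim = 8, M = 2, K = 50 (capped at the sample count), sub_dim = 4, N = 50
+// gives 16 + 2 * 50 * 4 * 4 = 1616 pivot bytes and 8 + 50 * 2 = 108 code bytes.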
+
+/// Serialize the PQ codebook (pivots) to bytes.
+///
+/// Format (native byte order):
+/// ```text
+/// i32: dimension
+/// i32: num_subspaces (M)
+/// i32: num_centroids (K)
+/// i32: sub_dimension
+/// f32[M * K * sub_dim]: centroid data (stored per-subspace)
+/// ```
+fn serialize_pivots(
+    table: &BasicTable,
+    dimension: usize,
+    num_subspaces: usize,
+    num_centroids: usize,
+    sub_dim: usize,
+) -> Vec<u8> {
+    // The pivots in BasicTable are stored row-major: ncenters rows × dim columns.
+    // Each row has all subspaces concatenated.
+    // We need to serialize in the per-subspace format expected by the reader:
+    // for subspace m: for centroid k: float[sub_dim]
+    let header_size = 4 * 4;
+    let data_size = num_subspaces * num_centroids * sub_dim * 4;
+    let mut buf = Vec::with_capacity(header_size + data_size);
+
+    buf.extend_from_slice(&(dimension as i32).to_ne_bytes());
+    buf.extend_from_slice(&(num_subspaces as i32).to_ne_bytes());
+    buf.extend_from_slice(&(num_centroids as i32).to_ne_bytes());
+    buf.extend_from_slice(&(sub_dim as i32).to_ne_bytes());
+
+    let pivots_view = table.view_pivots();
+    // Reorder from row-major (centroid × full_dim) to subspace-major:
+    // subspace m, centroid k → row k, columns [m*sub_dim .. (m+1)*sub_dim]
+    for m in 0..num_subspaces {
+        let col_start = m * sub_dim;
+        for k in 0..num_centroids {
+            let row = pivots_view.row(k);
+            for d in 0..sub_dim {
+                buf.extend_from_slice(&row[col_start + d].to_ne_bytes());
+            }
+        }
+    }
+
+    buf
+}
+
+/// Serialize compressed PQ codes to bytes.
+///
+/// Format (native byte order):
+/// ```text
+/// i32: num_vectors (N)
+/// i32: num_subspaces (M)
+/// byte[N * M]: PQ codes
+/// ```
+fn serialize_compressed(codes: &[Vec<u8>], num_subspaces: usize) -> Vec<u8> {
+    let header_size = 4 * 2;
+    let data_size = codes.len() * num_subspaces;
+    let mut buf = Vec::with_capacity(header_size + data_size);
+
+    buf.extend_from_slice(&(codes.len() as i32).to_ne_bytes());
+    buf.extend_from_slice(&(num_subspaces as i32).to_ne_bytes());
+
+    for code in codes {
+        buf.extend_from_slice(code);
+    }
+    buf
+}
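+
+// Reader-side sketch (illustrative, not part of this patch): with the layout
+// above, the code of vector `i` can be sliced straight out of `compressed_bytes`:
+//
+//     let n = i32::from_ne_bytes(buf[0..4].try_into().unwrap()) as usize;
+//     let m = i32::from_ne_bytes(buf[4..8].try_into().unwrap()) as usize;
+//     assert!(i < n);
+//     let code_i = &buf[8 + i * m..8 + (i + 1) * m]; // `m` bytes, one per subspace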
+
+/// Deterministic sampling without replacement (Fisher-Yates on indices).
+fn sample_indices_det(n: usize, sample_size: usize) -> Vec<usize> {
+    let mut indices: Vec<usize> = (0..n).collect();
+    let mut rng = SimpleRng::new(42);
+
+    let m = std::cmp::min(sample_size, n);
+    for i in 0..m {
+        let j = i + rng.next_usize(n - i);
+        indices.swap(i, j);
+    }
+    indices.truncate(m);
+    indices
+}
+
+/// Minimal deterministic PRNG (xorshift64).
+struct SimpleRng {
+    state: u64,
+}
+
+impl SimpleRng {
+    fn new(seed: u64) -> Self {
+        Self {
+            state: if seed == 0 { 1 } else { seed },
+        }
+    }
+
+    fn next_u64(&mut self) -> u64 {
+        let mut x = self.state;
+        x ^= x << 13;
+        x ^= x >> 7;
+        x ^= x << 17;
+        self.state = x;
+        x
+    }
+
+    fn next_usize(&mut self, bound: usize) -> usize {
+        (self.next_u64() % bound as u64) as usize
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_pq_train_encode_roundtrip() {
+        let dim = 8;
+        let num_subspaces = 2;
+        let n = 50;
+
+        // Generate some simple vectors.
+        let vectors: Vec<Vec<f32>> = (0..n)
+            .map(|i| {
+                (0..dim).map(|d| (i * dim + d) as f32 * 0.01).collect()
+            })
+            .collect();
+
+        let result = train_and_encode(&vectors, dim, num_subspaces, 100, 5).unwrap();
+
+        // Check pivots serialization.
+        assert!(result.pivots_bytes.len() > 16);
+
+        // Verify header: dim=8, M=2, K=min(256,50)=50, sub_dim=4
+        let dim_read = i32::from_ne_bytes(result.pivots_bytes[0..4].try_into().unwrap());
+        let m_read = i32::from_ne_bytes(result.pivots_bytes[4..8].try_into().unwrap());
+        let k_read = i32::from_ne_bytes(result.pivots_bytes[8..12].try_into().unwrap());
+        let sub_dim_read = i32::from_ne_bytes(result.pivots_bytes[12..16].try_into().unwrap());
+        assert_eq!(dim_read, 8);
+        assert_eq!(m_read, 2);
+        assert_eq!(k_read, 50);
+        assert_eq!(sub_dim_read, 4);
+
+        // Check expected pivots size: 16 header + 2*50*4*4 = 16 + 1600 = 1616
+        assert_eq!(result.pivots_bytes.len(), 16 + 2 * 50 * 4 * 4);
+
+        // Check compressed serialization.
+        assert!(result.compressed_bytes.len() > 8);
+
+        // Verify compressed header: N=50, M=2
+        let n_read = i32::from_ne_bytes(result.compressed_bytes[0..4].try_into().unwrap());
+        let m_comp_read = i32::from_ne_bytes(result.compressed_bytes[4..8].try_into().unwrap());
+        assert_eq!(n_read, 50);
+        assert_eq!(m_comp_read, 2);
+
+        // Check expected compressed size: 8 header + 50*2 = 108
+        assert_eq!(result.compressed_bytes.len(), 8 + 50 * 2);
+    }
+
+    #[test]
+    fn test_pq_invalid_params() {
+        let vectors: Vec<Vec<f32>> = vec![vec![1.0, 2.0, 3.0, 4.0]];
+
+        // dim not divisible by num_subspaces
+        let err = train_and_encode(&vectors, 4, 3, 100, 5).unwrap_err();
+        assert!(err.contains("Invalid PQ params"));
+
+        // num_subspaces = 0
+        let err = train_and_encode(&vectors, 4, 0, 100, 5).unwrap_err();
+        assert!(err.contains("Invalid PQ params"));
+
+        // empty vectors
+        let err = train_and_encode(&[], 4, 2, 100, 5).unwrap_err();
+        assert!(err.contains("No vectors"));
+    }
+}
diff --git a/paimon-diskann/paimon-diskann-jni/src/test/java/org/apache/paimon/diskann/IndexTest.java b/paimon-diskann/paimon-diskann-jni/src/test/java/org/apache/paimon/diskann/IndexTest.java
new file mode 100644
index 000000000000..7fb44ae56dd3
--- /dev/null
+++ b/paimon-diskann/paimon-diskann-jni/src/test/java/org/apache/paimon/diskann/IndexTest.java
@@ -0,0 +1,370 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.paimon.diskann;
+
+import org.junit.jupiter.api.Assumptions;
+import org.junit.jupiter.api.BeforeAll;
+import org.junit.jupiter.api.Test;
+
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
+import java.nio.FloatBuffer;
+import java.util.Random;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+/**
+ * Tests for the DiskANN Index class.
+ *
+ * <p>
Note: These tests require the native library to be built and available. They will be skipped + * if the native library is not found. + */ +class IndexTest { + + private static final int DIMENSION = 128; + private static final int NUM_VECTORS = 1000; + private static final int K = 10; + private static final int MAX_DEGREE = 64; + private static final int BUILD_LIST_SIZE = 100; + private static final int SEARCH_LIST_SIZE = 100; + private static final int INDEX_TYPE_MEMORY = 0; + + @BeforeAll + static void checkNativeLibrary() { + if (!DiskAnn.isLibraryLoaded()) { + try { + DiskAnn.loadLibrary(); + } catch (DiskAnnException e) { + StringBuilder errorMsg = new StringBuilder("DiskANN native library not available."); + errorMsg.append("\nError: ").append(e.getMessage()); + if (e.getCause() != null) { + errorMsg.append("\nCause: ").append(e.getCause().getMessage()); + } + errorMsg.append( + "\n\nTo run DiskANN tests, ensure the paimon-diskann-jni JAR" + + " with native libraries is available in the classpath."); + Assumptions.assumeTrue(false, errorMsg.toString()); + } + } + } + + @Test + void testBasicOperations() { + try (Index index = createIndex(MetricType.L2)) { + assertEquals(DIMENSION, index.getDimension()); + assertEquals(0, index.getCount()); + assertEquals(MetricType.L2, index.getMetricType()); + + // Add vectors with IDs + addVectors(index, NUM_VECTORS, DIMENSION); + assertEquals(NUM_VECTORS, index.getCount()); + + // Build the index + index.build(BUILD_LIST_SIZE); + + // Search + float[] queryVectors = createQueryVectors(1, DIMENSION); + float[] distances = new float[K]; + long[] labels = new long[K]; + + index.search(1, queryVectors, K, SEARCH_LIST_SIZE, distances, labels); + + // Verify labels are in valid range + for (int i = 0; i < K; i++) { + assertTrue( + labels[i] >= 0 && labels[i] < NUM_VECTORS, + "Label " + labels[i] + " out of range"); + } + + // Verify distances are non-negative for L2 + for (int i = 0; i < K; i++) { + assertTrue(distances[i] >= 0, "Distance should be non-negative for L2"); + } + } + } + + @Test + void testSequentialIds() { + try (Index index = createIndex(MetricType.L2)) { + addVectors(index, NUM_VECTORS, DIMENSION); + assertEquals(NUM_VECTORS, index.getCount()); + + index.build(BUILD_LIST_SIZE); + + // Search should return sequential IDs (0, 1, 2, ...) 
+ float[] queryVectors = createQueryVectors(1, DIMENSION); + float[] distances = new float[K]; + long[] labels = new long[K]; + + index.search(1, queryVectors, K, SEARCH_LIST_SIZE, distances, labels); + + for (int i = 0; i < K; i++) { + assertTrue( + labels[i] >= 0 && labels[i] < NUM_VECTORS, + "Label " + labels[i] + " out of range"); + } + } + } + + @Test + void testBatchSearch() { + try (Index index = createIndex(MetricType.L2)) { + addVectors(index, NUM_VECTORS, DIMENSION); + index.build(BUILD_LIST_SIZE); + + int numQueries = 5; + float[] queryVectors = createQueryVectors(numQueries, DIMENSION); + float[] distances = new float[numQueries * K]; + long[] labels = new long[numQueries * K]; + + index.search(numQueries, queryVectors, K, SEARCH_LIST_SIZE, distances, labels); + + // Read results for each query + for (int q = 0; q < numQueries; q++) { + for (int n = 0; n < K; n++) { + int idx = q * K + n; + assertTrue(labels[idx] >= 0 && labels[idx] < NUM_VECTORS); + assertTrue(distances[idx] >= 0); + } + } + } + } + + @Test + void testInnerProductMetric() { + try (Index index = createIndex(MetricType.INNER_PRODUCT)) { + assertEquals(MetricType.INNER_PRODUCT, index.getMetricType()); + + addVectors(index, NUM_VECTORS, DIMENSION); + index.build(BUILD_LIST_SIZE); + + float[] queryVectors = createQueryVectors(1, DIMENSION); + float[] distances = new float[K]; + long[] labels = new long[K]; + + index.search(1, queryVectors, K, SEARCH_LIST_SIZE, distances, labels); + + // DiskANN uses distance form for all metrics (lower = closer/more similar). + // For inner product the distance is derived so that results are still in + // ascending order by distance (the most similar result first). + for (int i = 1; i < K; i++) { + assertTrue( + distances[i] >= distances[i - 1], + "Distances should be sorted in ascending order (lower = more similar)"); + } + } + } + + @Test + void testCosineMetric() { + try (Index index = createIndex(MetricType.COSINE)) { + assertEquals(MetricType.COSINE, index.getMetricType()); + + addVectors(index, NUM_VECTORS, DIMENSION); + index.build(BUILD_LIST_SIZE); + + float[] queryVectors = createQueryVectors(1, DIMENSION); + float[] distances = new float[K]; + long[] labels = new long[K]; + + index.search(1, queryVectors, K, SEARCH_LIST_SIZE, distances, labels); + + // Cosine distance should be in [0, 2] range + for (int i = 0; i < K; i++) { + assertTrue(labels[i] >= 0, "Label should be non-negative"); + } + } + } + + @Test + void testSmallIndex() { + int dim = 2; + try (Index index = + Index.create(dim, MetricType.L2, INDEX_TYPE_MEMORY, MAX_DEGREE, BUILD_LIST_SIZE)) { + // Add a few vectors: [1,0], [0,1], [0.7,0.7] + ByteBuffer vectorBuffer = Index.allocateVectorBuffer(3, dim); + FloatBuffer floatView = vectorBuffer.asFloatBuffer(); + floatView.put(0, 1.0f); + floatView.put(1, 0.0f); // position 0: [1, 0] + floatView.put(2, 0.0f); + floatView.put(3, 1.0f); // position 1: [0, 1] + floatView.put(4, 0.7f); + floatView.put(5, 0.7f); // position 2: [0.7, 0.7] + + index.add(3, vectorBuffer); + index.build(BUILD_LIST_SIZE); + + // Query for [1, 0] - should find position 0 as nearest + float[] query = {1.0f, 0.0f}; + float[] distances = new float[1]; + long[] labels = new long[1]; + index.search(1, query, 1, SEARCH_LIST_SIZE, distances, labels); + + assertEquals(0L, labels[0], "Nearest to [1,0] should be position 0"); + assertEquals(0.0f, distances[0], 1e-5f, "Distance to self should be ~0"); + } + } + + @Test + void testSearchResultArrays() { + try (Index index = createIndex(MetricType.L2)) { + 
addVectors(index, 100, DIMENSION); + index.build(BUILD_LIST_SIZE); + + int numQueries = 3; + int k = 5; + float[] queryVectors = createQueryVectors(numQueries, DIMENSION); + float[] distances = new float[numQueries * k]; + long[] labels = new long[numQueries * k]; + + index.search(numQueries, queryVectors, k, SEARCH_LIST_SIZE, distances, labels); + + // Test reading individual results + for (int q = 0; q < numQueries; q++) { + for (int n = 0; n < k; n++) { + int idx = q * k + n; + assertTrue(labels[idx] >= 0 && labels[idx] < 100); + assertTrue(distances[idx] >= 0); + } + } + } + } + + @Test + void testBufferAllocationHelpers() { + // Test vector buffer allocation + ByteBuffer vectorBuffer = Index.allocateVectorBuffer(10, DIMENSION); + assertTrue(vectorBuffer.isDirect()); + assertEquals(ByteOrder.nativeOrder(), vectorBuffer.order()); + assertEquals(10 * DIMENSION * Float.BYTES, vectorBuffer.capacity()); + } + + @Test + void testErrorHandling() { + // Test buffer validation - wrong size buffer + try (Index index = createIndex(MetricType.L2)) { + ByteBuffer wrongSizeBuffer = + ByteBuffer.allocateDirect(10).order(ByteOrder.nativeOrder()); + assertThrows( + IllegalArgumentException.class, + () -> { + index.add(1, wrongSizeBuffer); + }); + } + + // Test non-direct buffer + try (Index index = createIndex(MetricType.L2)) { + ByteBuffer heapBuffer = ByteBuffer.allocate(DIMENSION * Float.BYTES); + assertThrows( + IllegalArgumentException.class, + () -> { + index.add(1, heapBuffer); + }); + } + + // Test serialize with non-direct buffer + try (Index index = createIndex(MetricType.L2)) { + ByteBuffer heapBuffer = ByteBuffer.allocate(100); + assertThrows( + IllegalArgumentException.class, + () -> { + index.serialize(heapBuffer); + }); + } + + // Test closed index + Index closedIndex = createIndex(MetricType.L2); + closedIndex.close(); + assertThrows( + IllegalStateException.class, + () -> { + closedIndex.getCount(); + }); + } + + @Test + void testQueryVectorArrayValidation() { + try (Index index = createIndex(MetricType.L2)) { + addVectors(index, 10, DIMENSION); + index.build(BUILD_LIST_SIZE); + + // Query vectors array too small + float[] tooSmall = new float[DIMENSION - 1]; + float[] distances = new float[K]; + long[] labels = new long[K]; + assertThrows( + IllegalArgumentException.class, + () -> { + index.search(1, tooSmall, K, SEARCH_LIST_SIZE, distances, labels); + }); + + // Distances array too small + float[] query = createQueryVectors(1, DIMENSION); + float[] smallDistances = new float[K - 1]; + assertThrows( + IllegalArgumentException.class, + () -> { + index.search(1, query, K, SEARCH_LIST_SIZE, smallDistances, labels); + }); + + // Labels array too small + long[] smallLabels = new long[K - 1]; + assertThrows( + IllegalArgumentException.class, + () -> { + index.search(1, query, K, SEARCH_LIST_SIZE, distances, smallLabels); + }); + } + } + + private Index createIndex(MetricType metricType) { + return Index.create(DIMENSION, metricType, INDEX_TYPE_MEMORY, MAX_DEGREE, BUILD_LIST_SIZE); + } + + /** Add random vectors to the index. */ + private void addVectors(Index index, int n, int d) { + ByteBuffer vectorBuffer = createVectorBuffer(n, d); + index.add(n, vectorBuffer); + } + + /** Create a direct ByteBuffer with random vectors. 
*/
+    private ByteBuffer createVectorBuffer(int n, int d) {
+        ByteBuffer buffer = Index.allocateVectorBuffer(n, d);
+        FloatBuffer floatView = buffer.asFloatBuffer();
+
+        Random random = new Random(42);
+        for (int i = 0; i < n * d; i++) {
+            floatView.put(i, random.nextFloat());
+        }
+
+        return buffer;
+    }
+
+    /** Create a float array with random query vectors. */
+    private float[] createQueryVectors(int n, int d) {
+        float[] vectors = new float[n * d];
+        Random random = new Random(42);
+        for (int i = 0; i < n * d; i++) {
+            vectors[i] = random.nextFloat();
+        }
+        return vectors;
+    }
+}
diff --git a/paimon-diskann/pom.xml b/paimon-diskann/pom.xml
new file mode 100644
index 000000000000..55f5d45df16f
--- /dev/null
+++ b/paimon-diskann/pom.xml
@@ -0,0 +1,40 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<!--
+Licensed to the Apache Software Foundation (ASF) under one
+or more contributor license agreements.  See the NOTICE file
+distributed with this work for additional information
+regarding copyright ownership.  The ASF licenses this file
+to you under the Apache License, Version 2.0 (the
+"License"); you may not use this file except in compliance
+with the License.  You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+-->
+<project xmlns="http://maven.apache.org/POM/4.0.0"
+         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
+
+    <modelVersion>4.0.0</modelVersion>
+
+    <parent>
+        <artifactId>paimon-parent</artifactId>
+        <groupId>org.apache.paimon</groupId>
+        <version>1.4-SNAPSHOT</version>
+    </parent>
+
+    <artifactId>paimon-diskann</artifactId>
+    <name>Paimon : DiskANN</name>
+    <packaging>pom</packaging>
+
+    <modules>
+        <module>paimon-diskann-jni</module>
+        <module>paimon-diskann-index</module>
+        <module>paimon-diskann-e2e-test</module>
+    </modules>
+</project>
diff --git a/pom.xml b/pom.xml
index 879f4bcbc065..999a677faa6f 100644
--- a/pom.xml
+++ b/pom.xml
@@ -532,6 +532,15 @@ under the License.
            true
+        <profile>
+            <id>paimon-diskann</id>
+            <modules>
+                <module>paimon-diskann</module>
+            </modules>
+            <activation>
+                <activeByDefault>true</activeByDefault>
+            </activation>
+        </profile>