diff --git a/.github/workflows/audit.yml b/.github/workflows/audit.yml index 4fbb565..4b99938 100644 --- a/.github/workflows/audit.yml +++ b/.github/workflows/audit.yml @@ -4,14 +4,14 @@ on: push: paths: - .github/workflows/audit.yml - - '**/Cargo.toml' - - '**/Cargo.lock' + - "**/Cargo.toml" + - "**/Cargo.lock" pull_request: paths: - - '**/Cargo.toml' - - '**/Cargo.lock' + - "**/Cargo.toml" + - "**/Cargo.lock" schedule: - - cron: '0 6 * * 1' # Monday at 6 AM UTC + - cron: "0 6 * * 1" # Monday at 6 AM UTC workflow_dispatch: permissions: @@ -21,15 +21,15 @@ jobs: audit: runs-on: ubuntu-latest steps: - - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 + - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 - name: Install Rust toolchain - uses: actions-rust-lang/setup-rust-toolchain@1780873c7b576612439a134613cc4cc74ce5538c # v1.15.2 + uses: actions-rust-lang/setup-rust-toolchain@1780873c7b576612439a134613cc4cc74ce5538c # v1.15.2 with: - cache: true # toolchain/components are specified in rust-toolchain.toml + cache: true # toolchain/components are specified in rust-toolchain.toml - name: Cache advisory database - uses: actions/cache@9255dc7a253b0ccc959486e2bca901246202afeb # v5.0.1 + uses: actions/cache@9255dc7a253b0ccc959486e2bca901246202afeb # v5.0.1 with: path: ~/.cargo/advisory-db key: advisory-db-${{ github.ref_name }}-v1 @@ -44,7 +44,7 @@ jobs: cargo audit - name: Upload audit results - uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0 + uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0 if: always() with: name: audit-results diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index e4c449e..86d2021 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -1,4 +1,5 @@ name: CI +# cspell:ignore rhysd binstall mktemp sha256sum sha256sums shasum permissions: contents: read concurrency: @@ -21,12 +22,17 @@ on: env: CARGO_TERM_COLOR: always RUST_BACKTRACE: 1 + ACTIONLINT_VERSION: "1.7.10" + MARKDOWNLINT_VERSION: "0.47.0" + CSPELL_VERSION: "9.4.0" + SHFMT_VERSION: "3.12.0" + UV_VERSION: "0.9.21" jobs: build: runs-on: ${{ matrix.os }} strategy: - fail-fast: false # Continue other jobs if one fails + fail-fast: false # Continue other jobs if one fails matrix: os: - ubuntu-latest @@ -41,47 +47,174 @@ jobs: target: x86_64-pc-windows-msvc steps: - - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 + - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 - name: Install Rust toolchain - uses: actions-rust-lang/setup-rust-toolchain@1780873c7b576612439a134613cc4cc74ce5538c # v1.15.2 + uses: actions-rust-lang/setup-rust-toolchain@1780873c7b576612439a134613cc4cc74ce5538c # v1.15.2 with: target: ${{ matrix.target }} - cache: true # Built-in caching; toolchain/components are specified in rust-toolchain.toml + cache: true # Built-in caching + # toolchain, components, etc. are specified in rust-toolchain.toml - name: Install just if: matrix.os != 'windows-latest' - uses: taiki-e/install-action@bfc291e1e39400b67eda124e4a7b4380e93b3390 # v2.65.0 + uses: taiki-e/install-action@4c6723ec9c638cccae824b8957c5085b695c8085 # v2.65.7 with: tool: just - name: Install uv (for Python scripts and pytest) if: matrix.os != 'windows-latest' - uses: astral-sh/setup-uv@681c641aba71e4a1c380be3ab5e12ad51f415867 # v7.1.6 + uses: astral-sh/setup-uv@681c641aba71e4a1c380be3ab5e12ad51f415867 # v7.1.6 with: - version: "0.9.17" # Pinned for reproducible CI + version: ${{ env.UV_VERSION }} - name: Install Node.js (for markdownlint and cspell) if: matrix.os != 'windows-latest' uses: actions/setup-node@395ad3262231945c25e8478fd5baf05154b1d79f # v6.1.0 with: - node-version: '20' + node-version: "20" - name: Install Node.js packages if: matrix.os != 'windows-latest' run: | - npm install -g markdownlint-cli cspell + npm install -g markdownlint-cli@${{ env.MARKDOWNLINT_VERSION }} cspell@${{ env.CSPELL_VERSION }} - - name: Install jq (Linux) + - name: Install taplo (for TOML formatting and linting) + if: matrix.os != 'windows-latest' + uses: taiki-e/install-action@4c6723ec9c638cccae824b8957c5085b695c8085 # v2.65.7 + with: + tool: taplo-cli + + - name: Install actionlint (Linux/macOS) + if: matrix.os != 'windows-latest' + run: | + set -euo pipefail + + # actionlint is published as prebuilt binaries (Go), not a Rust crate. + # Install directly from rhysd/actionlint releases to avoid cargo-binstall fallback failures. + # Verify SHA256 checksums from the upstream release for supply-chain hardening. + OS="$(uname -s)" + ARCH="$(uname -m)" + + case "$OS" in + Linux) ACTIONLINT_OS="linux" ;; + Darwin) ACTIONLINT_OS="darwin" ;; + *) + echo "Unsupported OS for actionlint: $OS" >&2 + exit 1 + ;; + esac + + case "$ARCH" in + x86_64|amd64) ACTIONLINT_ARCH="amd64" ;; + arm64|aarch64) ACTIONLINT_ARCH="arm64" ;; + *) + echo "Unsupported architecture for actionlint: $ARCH" >&2 + exit 1 + ;; + esac + + verify_sha256() { + local checksum_file="$1" + if command -v sha256sum >/dev/null 2>&1; then + sha256sum -c "$checksum_file" + else + shasum -a 256 -c "$checksum_file" + fi + } + + VERSION="${ACTIONLINT_VERSION}" + TARBALL="actionlint_${VERSION}_${ACTIONLINT_OS}_${ACTIONLINT_ARCH}.tar.gz" + CHECKSUMS_FILE="actionlint_${VERSION}_checksums.txt" + BASE_URL="https://github.com/rhysd/actionlint/releases/download/v${VERSION}" + + tmpdir="$(mktemp -d)" + trap 'rm -rf "$tmpdir"' EXIT + + curl -fsSL "${BASE_URL}/${TARBALL}" -o "$tmpdir/$TARBALL" + curl -fsSL "${BASE_URL}/${CHECKSUMS_FILE}" -o "$tmpdir/$CHECKSUMS_FILE" + + orig_dir="$PWD" + cd "$tmpdir" + awk -v f="$TARBALL" '$NF==f {print; found=1} END {exit found?0:1}' "$CHECKSUMS_FILE" > checksum.txt + verify_sha256 checksum.txt + cd "$orig_dir" + + tar -xzf "$tmpdir/$TARBALL" -C "$tmpdir" + + actionlint_path="$(find "$tmpdir" -type f -name actionlint | head -n 1)" + if [[ -z "$actionlint_path" ]]; then + echo "actionlint binary not found in $TARBALL" >&2 + exit 1 + fi + + sudo install -m 0755 "$actionlint_path" /usr/local/bin/actionlint + + - name: actionlint -version + if: matrix.os != 'windows-latest' + run: actionlint -version + + - name: Install additional tools (Linux) if: matrix.os == 'ubuntu-latest' run: | + # Install shellcheck, jq, and yamllint sudo apt-get update - sudo apt-get install -y jq + sudo apt-get install -y shellcheck jq yamllint + + # Install shfmt (pinned for CI consistency) + SHFMT_ASSET="shfmt_v${SHFMT_VERSION}_linux_amd64" + SHFMT_BASE_URL="https://github.com/mvdan/sh/releases/download/v${SHFMT_VERSION}" + + tmpdir="$(mktemp -d)" + trap 'rm -rf "$tmpdir"' EXIT + + curl -fsSL \ + "${SHFMT_BASE_URL}/${SHFMT_ASSET}" \ + -o "$tmpdir/${SHFMT_ASSET}" + curl -fsSL \ + "${SHFMT_BASE_URL}/sha256sums.txt" \ + -o "$tmpdir/sha256sums.txt" + + ( + cd "$tmpdir" + awk -v f="${SHFMT_ASSET}" '$NF==f {print; found=1} END {exit found?0:1}' sha256sums.txt > checksum.txt + sha256sum -c checksum.txt + ) - - name: Install jq (macOS) + sudo install -m 0755 "$tmpdir/${SHFMT_ASSET}" /usr/local/bin/shfmt + + - name: Install additional tools (macOS) if: matrix.os == 'macos-latest' run: | - brew install jq + # Install shellcheck, jq, and yamllint via Homebrew + brew install shellcheck jq yamllint + + # Install shfmt (pinned for CI consistency with Linux) + SHFMT_ARCH="amd64" + if [[ "$(uname -m)" == "arm64" ]]; then + SHFMT_ARCH="arm64" + fi + + SHFMT_ASSET="shfmt_v${SHFMT_VERSION}_darwin_${SHFMT_ARCH}" + SHFMT_BASE_URL="https://github.com/mvdan/sh/releases/download/v${SHFMT_VERSION}" + + tmpdir="$(mktemp -d)" + trap 'rm -rf "$tmpdir"' EXIT + + curl -fsSL \ + "${SHFMT_BASE_URL}/${SHFMT_ASSET}" \ + -o "$tmpdir/${SHFMT_ASSET}" + curl -fsSL \ + "${SHFMT_BASE_URL}/sha256sums.txt" \ + -o "$tmpdir/sha256sums.txt" + + ( + cd "$tmpdir" + awk -v f="${SHFMT_ASSET}" '$NF==f {print; found=1} END {exit found?0:1}' sha256sums.txt > checksum.txt + shasum -a 256 -c checksum.txt + ) + + sudo install -m 0755 "$tmpdir/${SHFMT_ASSET}" /usr/local/bin/shfmt - name: Run CI checks (Linux/macOS) if: matrix.os != 'windows-latest' diff --git a/.github/workflows/codacy.yml b/.github/workflows/codacy.yml index 274942d..87017f6 100644 --- a/.github/workflows/codacy.yml +++ b/.github/workflows/codacy.yml @@ -25,7 +25,7 @@ on: # The branches below must be a subset of the branches above branches: ["main"] schedule: - - cron: '42 0 * * 1' + - cron: "42 0 * * 1" workflow_dispatch: permissions: @@ -47,7 +47,7 @@ jobs: steps: # Checkout the repository to the GitHub Actions runner - name: Checkout code - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 + uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 - name: Set Codacy paths run: | @@ -144,7 +144,7 @@ jobs: # Upload the identified SARIF file - name: Upload identified SARIF file if: always() && env.SARIF_FILE != '' - uses: github/codeql-action/upload-sarif@b36bf259c813715f76eafece573914b94412cd13 # v3 + uses: github/codeql-action/upload-sarif@b36bf259c813715f76eafece573914b94412cd13 # v3 with: sarif_file: ${{ env.SARIF_FILE }} continue-on-error: true diff --git a/.github/workflows/codecov.yml b/.github/workflows/codecov.yml index cb54226..12be633 100644 --- a/.github/workflows/codecov.yml +++ b/.github/workflows/codecov.yml @@ -24,17 +24,17 @@ jobs: TARPAULIN_VERSION: "0.32.8" steps: - name: Checkout repository - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 + uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 with: - fetch-depth: 0 # Needed for Codecov diff analysis + fetch-depth: 0 # Needed for Codecov diff analysis - name: Install Rust toolchain - uses: actions-rust-lang/setup-rust-toolchain@1780873c7b576612439a134613cc4cc74ce5538c # v1.15.2 + uses: actions-rust-lang/setup-rust-toolchain@1780873c7b576612439a134613cc4cc74ce5538c # v1.15.2 with: - cache: true # toolchain/components are specified in rust-toolchain.toml + cache: true # toolchain/components are specified in rust-toolchain.toml - name: Cache tarpaulin - uses: actions/cache@9255dc7a253b0ccc959486e2bca901246202afeb # v5.0.1 + uses: actions/cache@9255dc7a253b0ccc959486e2bca901246202afeb # v5.0.1 with: path: ~/.cargo/bin/cargo-tarpaulin key: tarpaulin-${{ runner.os }}-${{ env.TARPAULIN_VERSION }} @@ -48,7 +48,7 @@ jobs: fi - name: Install just - uses: taiki-e/install-action@bfc291e1e39400b67eda124e4a7b4380e93b3390 # v2.65.0 + uses: taiki-e/install-action@bfc291e1e39400b67eda124e4a7b4380e93b3390 # v2.65.0 with: tool: just @@ -76,7 +76,7 @@ jobs: - name: Upload coverage to Codecov if: ${{ success() && hashFiles('coverage/cobertura.xml') != '' }} - uses: codecov/codecov-action@671740ac38dd9b0130fbe1cec585b89eea48d3de # v5.5.2 + uses: codecov/codecov-action@671740ac38dd9b0130fbe1cec585b89eea48d3de # v5.5.2 with: files: coverage/cobertura.xml flags: unittests @@ -86,7 +86,7 @@ jobs: CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} - name: Archive coverage results - uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0 + uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0 if: always() with: name: coverage-report diff --git a/.github/workflows/rust-clippy.yml b/.github/workflows/rust-clippy.yml index 91e2480..ed82805 100644 --- a/.github/workflows/rust-clippy.yml +++ b/.github/workflows/rust-clippy.yml @@ -6,7 +6,7 @@ name: "Clippy Security Analysis" on: # Only run on schedule and manual trigger to avoid duplication with CI schedule: - - cron: '17 22 * * 0' # Weekly on Sunday + - cron: "17 22 * * 0" # Weekly on Sunday workflow_dispatch: # Run on main branch pushes for security scanning push: @@ -24,15 +24,15 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout code - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 + uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 - name: Install Rust toolchain - uses: actions-rust-lang/setup-rust-toolchain@1780873c7b576612439a134613cc4cc74ce5538c # v1.15.2 + uses: actions-rust-lang/setup-rust-toolchain@1780873c7b576612439a134613cc4cc74ce5538c # v1.15.2 with: - cache: true # toolchain/components are specified in rust-toolchain.toml + cache: true # toolchain/components are specified in rust-toolchain.toml - name: Cache clippy tools - uses: actions/cache@9255dc7a253b0ccc959486e2bca901246202afeb # v5.0.1 + uses: actions/cache@9255dc7a253b0ccc959486e2bca901246202afeb # v5.0.1 with: path: | ~/.cargo/bin/clippy-sarif @@ -58,7 +58,7 @@ jobs: continue-on-error: true - name: Upload SARIF results - uses: github/codeql-action/upload-sarif@b36bf259c813715f76eafece573914b94412cd13 # v3 + uses: github/codeql-action/upload-sarif@b36bf259c813715f76eafece573914b94412cd13 # v3 with: sarif_file: rust-clippy-results.sarif category: "clippy" diff --git a/.gitignore b/.gitignore index 43fa25c..aa6c938 100644 --- a/.gitignore +++ b/.gitignore @@ -14,3 +14,5 @@ venv/ .pytest_cache/ .mypy_cache/ uv.lock +/node_modules/ +/package-lock.json diff --git a/.taplo.toml b/.taplo.toml new file mode 100644 index 0000000..9830e2f --- /dev/null +++ b/.taplo.toml @@ -0,0 +1,20 @@ +# Taplo configuration mirroring Rust-lang / Cargo TOML philosophy +# Goal: correctness + stability, NOT opinionated formatting + +[formatting] +# Match Cargo-style conservative formatting +align_entries = false +align_comments = false +reorder_keys = false +reorder_arrays = false +array_auto_expand = false +array_auto_collapse = false +compact_arrays = false +compact_inline_tables = false +column_width = 0 # unlimited, like rustfmt defaults +indent_string = " " # 4 spaces (Cargo standard) +trailing_newline = true +crlf = false + +# Note: Taplo config schema currently supports formatting/schema options only. +# Linting behavior is configured via `taplo lint` defaults and CLI flags. diff --git a/.yamllint b/.yamllint new file mode 100644 index 0000000..4884a52 --- /dev/null +++ b/.yamllint @@ -0,0 +1,18 @@ +--- +extends: default + +ignore: | + target*/ + target-*/ + **/target*/ + +rules: + line-length: + max: 120 + truthy: + allowed-values: ['true', 'false', 'on', 'off', 'yes', 'no'] + # Prettier uses a single space before inline comments ("foo: bar # comment"). + comments: + min-spaces-from-content: 1 + comments-indentation: disable + document-start: disable diff --git a/CITATION.cff b/CITATION.cff new file mode 100644 index 0000000..e520ccf --- /dev/null +++ b/CITATION.cff @@ -0,0 +1,24 @@ +cff-version: 1.2.0 +message: "If you use this software, please cite it as below." +type: software +title: "la-stack: Fast, stack-allocated linear algebra for fixed dimensions in Rust" +version: 0.1.2 +url: "https://github.com/acgetchell/la-stack" +repository-code: "https://github.com/acgetchell/la-stack" +authors: + - family-names: "Getchell" + given-names: "Adam" + email: "adam@adamgetchell.org" + orcid: "https://orcid.org/0000-0002-0797-0021" +keywords: + - "linear algebra" + - "geometry" + - "const generics" + - "LU" + - "LDLT" + - "Rust" +abstract: >- + la-stack is a Rust library providing fast, stack-allocated linear algebra primitives + and factorization routines (LU with partial pivoting and LDLT for symmetric SPD/PSD matrices) + for fixed small dimensions using const generics. +license: BSD-3-Clause diff --git a/Cargo.lock b/Cargo.lock index 43e283a..dafc92c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -693,7 +693,7 @@ dependencies = [ [[package]] name = "la-stack" -version = "0.1.1" +version = "0.1.2" dependencies = [ "approx", "criterion", diff --git a/Cargo.toml b/Cargo.toml index 3f42210..172db11 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "la-stack" -version = "0.1.1" +version = "0.1.2" edition = "2024" rust-version = "1.92" license = "BSD-3-Clause" @@ -8,16 +8,16 @@ description = "Small, stack-allocated linear algebra for fixed dimensions" readme = "README.md" documentation = "https://docs.rs/la-stack" repository = "https://github.com/acgetchell/la-stack" -categories = ["mathematics", "science"] -keywords = ["linear-algebra", "geometry", "const-generics"] +categories = [ "mathematics", "science" ] +keywords = [ "linear-algebra", "geometry", "const-generics" ] [dependencies] # Intentionally empty [dev-dependencies] approx = "0.5.1" -criterion = { version = "0.8.1", features = ["html_reports"] } -faer = { version = "0.23.2", default-features = false, features = ["std", "linalg"] } +criterion = { version = "0.8.1", features = [ "html_reports" ] } +faer = { version = "0.23.2", default-features = false, features = [ "std", "linalg" ] } nalgebra = "0.34.1" pastey = "0.2.1" proptest = "1.9.0" diff --git a/README.md b/README.md index e991be5..b0b4e15 100644 --- a/README.md +++ b/README.md @@ -22,6 +22,7 @@ while keeping the API intentionally small and explicit. - `Vector` for fixed-length vectors (`[f64; D]` today) - `Matrix` for fixed-size square matrices (`[[f64; D]; D]` today) - `Lu` for LU factorization with partial pivoting (solve + det) +- `Ldlt` for LDLT factorization without pivoting (solve + det; symmetric SPD/PSD) ## โœจ Design goals @@ -80,13 +81,35 @@ for (x_i, e_i) in x.iter().zip(expected.iter()) { } ``` +Compute a determinant for a symmetric SPD matrix via LDLT (no pivoting). + +For symmetric positive-definite matrices, `LDL^T` is essentially a square-root-free form of the Cholesky decomposition +(you can recover a Cholesky factor by absorbing `sqrt(D)` into `L`): + +```rust +use la_stack::prelude::*; + +// This matrix is symmetric positive-definite (A = L*L^T) so LDLT works without pivoting. +let a = Matrix::<5>::from_rows([ + [1.0, 1.0, 0.0, 0.0, 0.0], + [1.0, 2.0, 1.0, 0.0, 0.0], + [0.0, 1.0, 2.0, 1.0, 0.0], + [0.0, 0.0, 1.0, 2.0, 1.0], + [0.0, 0.0, 0.0, 1.0, 2.0], +]); + +let det = a.ldlt(DEFAULT_SINGULAR_TOL).unwrap().det(); +assert!((det - 1.0).abs() <= 1e-12); +``` + ## ๐Ÿงฉ API at a glance | Type | Storage | Purpose | Key methods | |---|---|---|---| | `Vector` | `[f64; D]` | Fixed-length vector | `new`, `zero`, `dot`, `norm2_sq` | -| `Matrix` | `[[f64; D]; D]` | Fixed-size square matrix | `from_rows`, `zero`, `identity`, `lu`, `det` | +| `Matrix` | `[[f64; D]; D]` | Fixed-size square matrix | `from_rows`, `zero`, `identity`, `lu`, `ldlt`, `det` | | `Lu` | `Matrix` + pivot array | Factorization for solves/det | `solve_vec`, `det` | +| `Ldlt` | `Matrix` | Factorization for symmetric SPD/PSD solves/det | `solve_vec`, `det` | Storage shown above reflects the current `f64` implementation. @@ -107,12 +130,23 @@ A short contributor workflow: ```bash cargo install just -just ci # lint + fast tests + bench compile -just commit-check # lint + all tests + examples +just setup # install/verify dev tools + sync Python deps +just check # lint/validate (non-mutating) +just fix # apply auto-fixes (mutating) +just ci # lint + tests + examples + bench compile ``` For the full set of developer commands, see `just --list` and `WARP.md`. +## ๐Ÿ“ Citation + +If you use this library in academic work, please cite it using [CITATION.cff](CITATION.cff) (or GitHub's +"Cite this repository" feature). A Zenodo DOI will be added for tagged releases. + +## ๐Ÿ“š References + +For canonical references to LU / `LDL^T` algorithms used by this crate, see [REFERENCES.md](REFERENCES.md). + ## ๐Ÿ“Š Benchmarks (vs nalgebra/faer) ![LU solve (factor + solve): median time vs dimension](docs/assets/bench/vs_linalg_lu_solve_median.svg) diff --git a/REFERENCES.md b/REFERENCES.md new file mode 100644 index 0000000..89d4a8c --- /dev/null +++ b/REFERENCES.md @@ -0,0 +1,55 @@ +# References and citations + +## How to cite this library + +If you use this library in your research or project, please cite it using the information in +[CITATION.cff](CITATION.cff). This file contains structured citation metadata that can be +processed by GitHub and other platforms. + +A Zenodo DOI will be added for tagged releases. + +## Linear algebra algorithms + +### LU decomposition (Gaussian elimination with partial pivoting) + +The LU implementation in `la-stack` follows the standard Gaussian elimination / LU factorization +approach with partial pivoting for numerical stability. + +See references [1-3] below. + +### LDL^T factorization (symmetric SPD/PSD) + +The LDL^T (often abbreviated "LDLT") implementation in `la-stack` is intended for symmetric positive +definite (SPD) and positive semi-definite (PSD) matrices (e.g. Gram matrices), and does not perform +pivoting. + +For background on the SPD/PSD setting, see [4-5]. For pivoted variants used for symmetric *indefinite* +matrices, see [6]. + +## References + +### LU / Gaussian elimination + +1. Trefethen, Lloyd N., and Robert S. Schreiber. "Average-case stability of Gaussian elimination." + *SIAM Journal on Matrix Analysis and Applications* 11.3 (1990): 335โ€“360. + [PDF](https://people.maths.ox.ac.uk/trefethen/publication/PDF/1990_44.pdf) +2. Businger, P. A. "Monitoring the Numerical Stability of Gaussian Elimination." + *Numerische Mathematik* 16 (1970/71): 360โ€“361. + [Full text](https://eudml.org/doc/132040) +3. Huang, Han, and K. Tikhomirov. "Average-case analysis of the Gaussian elimination with partial pivoting." + *Probability Theory and Related Fields* 189 (2024): 501โ€“567. + [Open-access PDF](https://link.springer.com/article/10.1007/s00440-024-01276-2) (also: [arXiv:2206.01726](https://arxiv.org/abs/2206.01726)) + +### LDL^T / Cholesky (symmetric SPD/PSD) + +4. Cholesky, Andre-Louis. "On the numerical solution of systems of linear equations" + (manuscript dated 2 Dec 1910; published 2005). + Scan + English analysis: [BibNum](https://www.bibnum.education.fr/mathematiques/algebre/sur-la-resolution-numerique-des-systemes-d-equations-lineaires) +5. Brezinski, Claude. "La methode de Cholesky." (2005). + [PDF](https://eudml.org/doc/252115) + +### Pivoted LDL^T (symmetric indefinite) + +6. Bunch, J. R., L. Kaufman, and B. N. Parlett. "Decomposition of a Symmetric Matrix." + *Numerische Mathematik* 27 (1976/1977): 95โ€“110. + [Full text](https://eudml.org/doc/132435) diff --git a/WARP.md b/WARP.md index 06c5920..3c97a7b 100644 --- a/WARP.md +++ b/WARP.md @@ -12,27 +12,34 @@ When making changes in this repo, prioritize (in order): ## Common commands +- All tests (Rust + Python): `just test-all` +- Benchmarks: `cargo bench` (or `just bench`) - Build (debug): `cargo build` (or `just build`) - Build (release): `cargo build --release` (or `just build-release`) -- Fast compile check (no binary produced): `cargo check` (or `just check`) -- Run tests: `cargo test` (or `just test`) -- Run a single test (by name filter): `cargo test solve_2x2_basic` (or the full path: `cargo test lu::tests::solve_2x2_basic`) +- CI simulation (lint + tests + examples + bench compile): `just ci` +- Coverage (CI XML): `just coverage-ci` +- Coverage (HTML): `just coverage` +- Fast compile check (no binary produced): `cargo check` (or `just check-fast`) +- Fast Rust tests (lib + doc): `just test` - Format: `cargo fmt` (or `just fmt`) +- Integration tests: `just test-integration` - Lint (Clippy): `cargo clippy --all-targets --all-features -- -D warnings` (or `just clippy`) -- Spell check: `just spell-check` (uses `cspell.json` at repo root; keep the `words` list sorted lexicographically) -- Run benchmarks: `cargo bench` (or `just bench`) +- Lint/validate: `just check` +- Pre-commit validation: `just ci` +- Python tests: `just test-python` +- Run a single test (by name filter): `cargo test solve_2x2_basic` (or the full path: `cargo test lu::tests::solve_2x2_basic`) - Run examples: `just examples` (or `cargo run --example det_5x5` / `cargo run --example solve_5x5`) -- CI simulation (lint + tests + bench compile): `just ci` -- Pre-commit validation: `just commit-check` +- Spell check: `just spell-check` (uses `cspell.json` at repo root; keep the `words` list sorted lexicographically) ## Code structure (big picture) - This is a single Rust *library crate* (no `src/main.rs`). The crate root is `src/lib.rs`. - The linear algebra implementation is split across: - - `src/lib.rs`: crate root + shared items (`LaError`, `DEFAULT_PIVOT_TOL`) + re-exports + - `src/lib.rs`: crate root + shared items (`LaError`, `DEFAULT_SINGULAR_TOL`, `DEFAULT_PIVOT_TOL`) + re-exports - `src/vector.rs`: `Vector` (`[f64; D]`) - `src/matrix.rs`: `Matrix` (`[[f64; D]; D]`) + helpers (`get`, `set`, `inf_norm`, `det`) - `src/lu.rs`: `Lu` factorization with partial pivoting (`solve_vec`, `det`) + - `src/ldlt.rs`: `Ldlt` factorization without pivoting for symmetric SPD/PSD matrices (`solve_vec`, `det`) - A minimal `justfile` exists for common workflows (see `just --list`). - The public API re-exports these items from `src/lib.rs`. - Dev-only benchmarks live in `benches/vs_linalg.rs` (Criterion + nalgebra/faer comparison). diff --git a/cspell.json b/cspell.json index c8d20b0..9ee589b 100644 --- a/cspell.json +++ b/cspell.json @@ -4,7 +4,10 @@ "useGitignore": true, "words": [ "acgetchell", + "addopts", "blas", + "Brezinski", + "Businger", "capsys", "Clippy", "clippy", @@ -14,6 +17,7 @@ "doctests", "elif", "endgroup", + "esac", "f128", "f32", "f64", @@ -22,16 +26,24 @@ "generics", "Getchell", "gnuplot", + "Golub", + "Higham", "Justfile", "justfile", "keepends", "laerror", "lapack", + "LDLT", + "Ldlt", + "ldlt", "linalg", "linespoints", "logscale", "lu", "markdownlint", + "Mathematik", + "methode", + "minversion", "MSRV", "msvc", "mult", @@ -43,7 +55,10 @@ "noplot", "nrhs", "nrows", + "Numerische", "openblas", + "orcid", + "Parlett", "pastey", "patchlevel", "pipefail", @@ -55,11 +70,23 @@ "pytest", "rug", "RUSTDOCFLAGS", + "RUSTFLAGS", + "rustup", + "samply", "sarif", + "Schreiber", "semgrep", "setuptools", "shellcheck", + "SHFMT", + "shfmt", + "submatrix", "taiki", + "Taplo", + "taplo", + "testpaths", + "Tikhomirov", + "Trefethen", "tridiagonal", "unittests", "usize", @@ -67,7 +94,8 @@ "xlabel", "xtics", "ylabel", - "yerrorlines" + "yerrorlines", + "Zenodo" ], "ignorePaths": [ "**/.git/**", diff --git a/justfile b/justfile index 4230122..96a088a 100644 --- a/justfile +++ b/justfile @@ -6,24 +6,67 @@ # Use bash with strict error handling for all recipes set shell := ["bash", "-euo", "pipefail", "-c"] -# Internal helper: ensure uv is installed -_ensure-uv: +# Internal helpers: ensure external tooling is installed +_ensure-actionlint: #!/usr/bin/env bash set -euo pipefail - command -v uv >/dev/null || { echo "โŒ 'uv' not found. Install: https://github.com/astral-sh/uv (macOS: brew install uv)"; exit 1; } + command -v actionlint >/dev/null || { echo "โŒ 'actionlint' not found. See 'just setup' or https://github.com/rhysd/actionlint"; exit 1; } -# GitHub Actions workflow validation (optional) -action-lint: +_ensure-jq: #!/usr/bin/env bash set -euo pipefail - if ! command -v actionlint >/dev/null; then - echo "โš ๏ธ 'actionlint' not found. Install: https://github.com/rhysd/actionlint" + command -v jq >/dev/null || { echo "โŒ 'jq' not found. See 'just setup' or install: brew install jq"; exit 1; } + +_ensure-npx: + #!/usr/bin/env bash + set -euo pipefail + command -v npx >/dev/null || { echo "โŒ 'npx' not found. See 'just setup' or install Node.js (for npx tools): https://nodejs.org"; exit 1; } + +_ensure-prettier-or-npx: + #!/usr/bin/env bash + set -euo pipefail + if command -v prettier >/dev/null; then exit 0 fi + command -v npx >/dev/null || { + echo "โŒ Neither 'prettier' nor 'npx' found. Install via npm (recommended): npm i -g prettier" + echo " Or install Node.js (for npx): https://nodejs.org" + exit 1 + } + +_ensure-shellcheck: + #!/usr/bin/env bash + set -euo pipefail + command -v shellcheck >/dev/null || { echo "โŒ 'shellcheck' not found. See 'just setup' or https://www.shellcheck.net"; exit 1; } + +_ensure-shfmt: + #!/usr/bin/env bash + set -euo pipefail + command -v shfmt >/dev/null || { echo "โŒ 'shfmt' not found. See 'just setup' or install: brew install shfmt"; exit 1; } + +_ensure-taplo: + #!/usr/bin/env bash + set -euo pipefail + command -v taplo >/dev/null || { echo "โŒ 'taplo' not found. See 'just setup' or install: brew install taplo (or: cargo install taplo-cli)"; exit 1; } + +_ensure-uv: + #!/usr/bin/env bash + set -euo pipefail + command -v uv >/dev/null || { echo "โŒ 'uv' not found. See 'just setup' or https://github.com/astral-sh/uv"; exit 1; } + +_ensure-yamllint: + #!/usr/bin/env bash + set -euo pipefail + command -v yamllint >/dev/null || { echo "โŒ 'yamllint' not found. See 'just setup' or install: brew install yamllint"; exit 1; } + +# GitHub Actions workflow validation +action-lint: _ensure-actionlint + #!/usr/bin/env bash + set -euo pipefail files=() while IFS= read -r -d '' file; do files+=("$file") - done < <(git ls-files -z '.github/workflows/*.yaml' '.github/workflows/*.yml') + done < <(git ls-files -z '.github/workflows/*.yml' '.github/workflows/*.yaml') if [ "${#files[@]}" -gt 0 ]; then printf '%s\0' "${files[@]}" | xargs -0 actionlint else @@ -34,8 +77,10 @@ action-lint: bench: cargo bench +# Compile benchmarks without running them, treating warnings as errors. +# This catches bench/release-profile-only warnings that won't show up in normal debug-profile runs. bench-compile: - cargo bench --no-run + RUSTFLAGS='-D warnings' cargo bench --no-run # Bench the la-stack vs nalgebra/faer comparison suite. bench-vs-linalg filter="": @@ -66,25 +111,29 @@ build: build-release: cargo build --release -check: +# Check (non-mutating): run all linters/validators +check: lint + @echo "โœ… Checks complete!" + +# Fast compile check (no binary produced) +check-fast: cargo check -# CI simulation (matches delaunay's `just ci` shape) -ci: lint test-all bench-compile - @echo "๐ŸŽฏ CI simulation complete!" +# CI simulation: comprehensive validation (matches CI expectations) +# Runs: checks + all tests (Rust + Python) + examples + bench compile +ci: check bench-compile test-all examples + @echo "๐ŸŽฏ CI checks complete!" +# Clean build artifacts clean: cargo clean + rm -rf target/tarpaulin + rm -rf coverage # Code quality and formatting clippy: cargo clippy --workspace --all-targets --all-features -- -D warnings -W clippy::pedantic -W clippy::nursery -W clippy::cargo -# Pre-commit workflow: comprehensive validation before committing -# Runs: linting + all Rust tests (lib + doc + integration) + examples -commit-check: lint test-all bench-compile examples - @echo "๐Ÿš€ Ready to commit! All checks passed!" - # Coverage (cargo-tarpaulin) # # Common tarpaulin arguments for all coverage runs @@ -124,6 +173,7 @@ coverage-ci: default: @just --list +# Documentation build check doc-check: RUSTDOCFLAGS='-D warnings' cargo doc --no-deps @@ -132,22 +182,61 @@ examples: cargo run --quiet --example det_5x5 cargo run --quiet --example solve_5x5 +# Fix (mutating): apply formatters/auto-fixes +fix: toml-fmt fmt python-fix shell-fmt markdown-fix yaml-fix + @echo "โœ… Fixes applied!" + +# Rust formatting fmt: cargo fmt --all fmt-check: - cargo fmt --check + cargo fmt --all -- --check + +help-workflows: + @echo "Common Just workflows:" + @echo " just check # Run lint/validators (non-mutating)" + @echo " just check-fast # Fast compile check (cargo check)" + @echo " just ci # Full CI simulation (check + tests + examples + bench compile)" + @echo " just fix # Apply formatters/auto-fixes (mutating)" + @echo " just setup # Install/verify dev tools + sync Python deps" + @echo "" + @echo "Benchmarks:" + @echo " just bench # Run benchmarks" + @echo " just bench-compile # Compile benches with warnings-as-errors" + @echo " just bench-vs-linalg # Run vs_linalg bench (optional filter)" + @echo " just bench-vs-linalg-quick # Quick vs_linalg bench (reduced samples)" + @echo "" + @echo "Benchmark plotting:" + @echo " just plot-vs-linalg # Plot Criterion results (CSV + SVG)" + @echo " just plot-vs-linalg-readme # Plot + update README benchmark table" + @echo "" + @echo "Setup:" + @echo " just setup # Setup project environment (depends on setup-tools)" + @echo " just setup-tools # Install/verify external tooling (best-effort)" + @echo "" + @echo "Testing:" + @echo " just coverage # Generate coverage report (HTML)" + @echo " just coverage-ci # Generate coverage for CI (XML)" + @echo " just examples # Run examples" + @echo " just test # Lib + doc tests (fast)" + @echo " just test-all # All tests (Rust + Python)" + @echo " just test-integration # Integration tests" + @echo " just test-python # Python tests only (pytest)" + @echo "" + @echo "Note: Some recipes require external tools. Run 'just setup-tools' (tooling) or 'just setup' (full env) first." # Lint groups (delaunay-style) lint: lint-code lint-docs lint-config -lint-code: fmt-check clippy doc-check python-lint +lint-code: fmt-check clippy doc-check python-check shell-check -lint-config: validate-json action-lint +lint-config: validate-json toml-lint toml-fmt-check yaml-lint action-lint -lint-docs: markdown-lint spell-check +lint-docs: markdown-check spell-check -markdown-lint: +# Markdown +markdown-check: _ensure-npx #!/usr/bin/env bash set -euo pipefail files=() @@ -155,11 +244,27 @@ markdown-lint: files+=("$file") done < <(git ls-files -z '*.md') if [ "${#files[@]}" -gt 0 ]; then + printf '%s\0' "${files[@]}" | xargs -0 -n100 npx markdownlint --config .markdownlint.json + else + echo "No markdown files found to check." + fi + +markdown-fix: _ensure-npx + #!/usr/bin/env bash + set -euo pipefail + files=() + while IFS= read -r -d '' file; do + files+=("$file") + done < <(git ls-files -z '*.md') + if [ "${#files[@]}" -gt 0 ]; then + echo "๐Ÿ“ markdownlint --fix (${#files[@]} files)" printf '%s\0' "${files[@]}" | xargs -0 -n100 npx markdownlint --config .markdownlint.json --fix else - echo "No markdown files found to lint." + echo "No markdown files found to format." fi +markdown-lint: markdown-check + # Plot: generate a single time-vs-dimension SVG from Criterion results. plot-vs-linalg metric="lu_solve" stat="median" sample="new" log_y="false": python-sync #!/usr/bin/env bash @@ -181,14 +286,161 @@ plot-vs-linalg-readme metric="lu_solve" stat="median" sample="new" log_y="false" uv run criterion-dim-plot "${args[@]}" # Python tooling (uv) -python-lint: python-sync +python-check: python-typecheck + uv run ruff format --check scripts/ + uv run ruff check scripts/ + +python-fix: python-sync uv run ruff check scripts/ --fix uv run ruff format scripts/ - uv run mypy scripts/criterion_dim_plot.py + +python-lint: python-check python-sync: _ensure-uv uv sync --group dev +python-typecheck: python-sync + uv run mypy scripts/criterion_dim_plot.py + +# Setup +setup: setup-tools + #!/usr/bin/env bash + set -euo pipefail + echo "Setting up la-stack development environment..." + echo "Note: Rust toolchain and components managed by rust-toolchain.toml (if present)" + echo "" + + echo "Installing Python tooling..." + uv sync --group dev + echo "" + + echo "Building project..." + cargo build + echo "โœ… Setup complete! Run 'just help-workflows' to see available commands." + +# Development tooling installation (best-effort) +setup-tools: + #!/usr/bin/env bash + set -euo pipefail + + echo "๐Ÿ”ง Ensuring tooling required by just recipes is installed..." + echo "" + + os="$(uname -s || true)" + + have() { command -v "$1" >/dev/null 2>&1; } + + install_with_brew() { + local formula="$1" + if brew list --versions "$formula" >/dev/null 2>&1; then + echo " โœ“ $formula (brew)" + else + echo " โณ Installing $formula (brew)..." + HOMEBREW_NO_AUTO_UPDATE=1 brew install "$formula" + fi + } + + if have brew; then + echo "Using Homebrew (brew) to install missing tools..." + install_with_brew uv + install_with_brew jq + install_with_brew taplo + install_with_brew yamllint + install_with_brew shfmt + install_with_brew shellcheck + install_with_brew actionlint + install_with_brew node + echo "" + else + echo "โš ๏ธ 'brew' not found; skipping automatic tool installation." + if [[ "$os" == "Darwin" ]]; then + echo "Install Homebrew from https://brew.sh (recommended), or install the following tools manually:" + else + echo "Install the following tools via your system package manager:" + fi + echo " uv, jq, taplo, yamllint, shfmt, shellcheck, actionlint, node+npx" + echo "" + fi + + echo "Ensuring Rust toolchain + components..." + if ! have rustup; then + echo "โŒ 'rustup' not found. Install Rust via https://rustup.rs and re-run: just setup-tools" + exit 1 + fi + rustup component add clippy rustfmt rust-docs rust-src + echo "" + + echo "Ensuring cargo tools..." + if ! have samply; then + echo " โณ Installing samply (cargo)..." + cargo install --locked samply + else + echo " โœ“ samply" + fi + + if ! have cargo-tarpaulin; then + if [[ "$os" == "Linux" ]]; then + echo " โณ Installing cargo-tarpaulin (cargo)..." + cargo install --locked cargo-tarpaulin + else + echo " โš ๏ธ Skipping cargo-tarpaulin install on $os (coverage is typically Linux-only)" + fi + else + echo " โœ“ cargo-tarpaulin" + fi + + echo "" + echo "Verifying required commands are available..." + missing=0 + for cmd in uv jq taplo yamllint shfmt shellcheck actionlint node npx; do + if have "$cmd"; then + echo " โœ“ $cmd" + else + echo " โœ— $cmd" + missing=1 + fi + done + if [ "$missing" -ne 0 ]; then + echo "" + echo "โŒ Some required tools are still missing." + echo "Fix the installs above and re-run: just setup-tools" + exit 1 + fi + + echo "" + echo "โœ… Tooling setup complete." + +# Shell scripts +shell-check: _ensure-shellcheck _ensure-shfmt + #!/usr/bin/env bash + set -euo pipefail + files=() + while IFS= read -r -d '' file; do + files+=("$file") + done < <(git ls-files -z '*.sh') + if [ "${#files[@]}" -gt 0 ]; then + printf '%s\0' "${files[@]}" | xargs -0 -n4 shellcheck -x + printf '%s\0' "${files[@]}" | xargs -0 shfmt -d + else + echo "No shell files found to check." + fi + +shell-fmt: _ensure-shfmt + #!/usr/bin/env bash + set -euo pipefail + files=() + while IFS= read -r -d '' file; do + files+=("$file") + done < <(git ls-files -z '*.sh') + if [ "${#files[@]}" -gt 0 ]; then + echo "๐Ÿงน shfmt -w (${#files[@]} files)" + printf '%s\0' "${files[@]}" | xargs -0 -n1 shfmt -w + else + echo "No shell files found to format." + fi + +shell-lint: shell-check + # Spell check (cspell) # # Requires either: @@ -207,11 +459,8 @@ spell-check: exit 1 fi -# Testing (delaunay-style split) -# - test: lib + doc tests (fast) -# - test-all: all tests (Rust + Python) -# - test-integration: tests/ (if present) - +# Testing +# test: runs only lib and doc tests (fast) test: cargo test --lib --verbose cargo test --doc --verbose @@ -225,8 +474,48 @@ test-integration: test-python: python-sync uv run pytest -q +# TOML +toml-fmt: _ensure-taplo + #!/usr/bin/env bash + set -euo pipefail + files=() + while IFS= read -r -d '' file; do + files+=("$file") + done < <(git ls-files -z '*.toml') + if [ "${#files[@]}" -gt 0 ]; then + taplo fmt "${files[@]}" + else + echo "No TOML files found to format." + fi + +toml-fmt-check: _ensure-taplo + #!/usr/bin/env bash + set -euo pipefail + files=() + while IFS= read -r -d '' file; do + files+=("$file") + done < <(git ls-files -z '*.toml') + if [ "${#files[@]}" -gt 0 ]; then + taplo fmt --check "${files[@]}" + else + echo "No TOML files found to check." + fi + +toml-lint: _ensure-taplo + #!/usr/bin/env bash + set -euo pipefail + files=() + while IFS= read -r -d '' file; do + files+=("$file") + done < <(git ls-files -z '*.toml') + if [ "${#files[@]}" -gt 0 ]; then + taplo lint "${files[@]}" + else + echo "No TOML files found to lint." + fi + # File validation -validate-json: +validate-json: _ensure-jq #!/usr/bin/env bash set -euo pipefail files=() @@ -238,3 +527,51 @@ validate-json: else echo "No JSON files found to validate." fi + +# YAML +yaml-fix: _ensure-prettier-or-npx + #!/usr/bin/env bash + set -euo pipefail + files=() + while IFS= read -r -d '' file; do + files+=("$file") + done < <(git ls-files -z '*.yml' '*.yaml') + if [ "${#files[@]}" -gt 0 ]; then + echo "๐Ÿ“ prettier --write (YAML, ${#files[@]} files)" + + cmd=() + if command -v prettier >/dev/null; then + cmd=(prettier --write --print-width 120) + elif command -v npx >/dev/null; then + # Prefer non-interactive installs when supported (newer npm/npx). + # NOTE: With `set -u`, expanding an empty array like "${arr[@]}" can error on older bash. + cmd=(npx) + if npx --help 2>&1 | grep -q -- '--yes'; then + cmd+=(--yes) + fi + cmd+=(prettier --write --print-width 120) + else + echo "โŒ 'prettier' not found. Install via npm (recommended): npm i -g prettier" + echo " Or install Node.js (for npx): https://nodejs.org" + exit 1 + fi + + # Use CLI flags instead of a repo-wide prettier config: keeps the scope to YAML only. + printf '%s\0' "${files[@]}" | xargs -0 -n100 "${cmd[@]}" + else + echo "No YAML files found to format." + fi + +yaml-lint: _ensure-yamllint + #!/usr/bin/env bash + set -euo pipefail + files=() + while IFS= read -r -d '' file; do + files+=("$file") + done < <(git ls-files -z '*.yml' '*.yaml') + if [ "${#files[@]}" -gt 0 ]; then + echo "๐Ÿ” yamllint (${#files[@]} files)" + yamllint --strict -c .yamllint "${files[@]}" + else + echo "No YAML files found to lint." + fi diff --git a/pyproject.toml b/pyproject.toml index ded8543..5ab9ed2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,5 +1,5 @@ [build-system] -requires = ["setuptools>=65.0.0", "wheel"] +requires = [ "setuptools>=65.0.0", "wheel" ] build-backend = "setuptools.build_meta" [project] @@ -8,11 +8,11 @@ version = "0.1.0" description = "Python utility scripts for the la-stack Rust library" readme = "README.md" requires-python = ">=3.11" -license = {text = "BSD-3-Clause"} +license = { text = "BSD-3-Clause" } authors = [ - {name = "Adam Getchell", email = "adam@adamgetchell.org"}, + { name = "Adam Getchell", email = "adam@adamgetchell.org" }, ] -keywords = ["linear-algebra", "benchmarking", "utilities", "la-stack"] +keywords = [ "linear-algebra", "benchmarking", "utilities", "la-stack" ] classifiers = [ "Development Status :: 4 - Beta", "Intended Audience :: Developers", @@ -28,7 +28,7 @@ classifiers = [ ] # No runtime dependencies currently; scripts rely on the standard library. -dependencies = [] +dependencies = [ ] [project.urls] "Homepage" = "https://github.com/acgetchell/la-stack" @@ -41,43 +41,43 @@ criterion-dim-plot = "criterion_dim_plot:main" # Configure setuptools to find modules in scripts/ directory. [tool.setuptools] -package-dir = {"" = "scripts"} -py-modules = ["criterion_dim_plot"] +package-dir = { "" = "scripts" } +py-modules = [ "criterion_dim_plot" ] [tool.ruff] line-length = 150 target-version = "py311" -src = ["scripts"] +src = [ "scripts" ] [tool.ruff.lint] -select = ["E", "F", "W", "I", "N", "UP", "YTT", "S", "BLE", "FBT", "B", "A", "COM", "C4", "DTZ", "T10", "EM", "EXE", "ISC", "ICN", "G", "INP", "PIE", "T20", "PYI", "PT", "Q", "RSE", "RET", "SLF", "SIM", "TID", "TCH", "ARG", "PTH", "ERA", "PD", "PGH", "PL", "TRY", "NPY", "RUF"] -fixable = ["ALL"] -unfixable = [] +select = [ "E", "F", "W", "I", "N", "UP", "YTT", "S", "BLE", "FBT", "B", "A", "COM", "C4", "DTZ", "T10", "EM", "EXE", "ISC", "ICN", "G", "INP", "PIE", "T20", "PYI", "PT", "Q", "RSE", "RET", "SLF", "SIM", "TID", "TCH", "ARG", "PTH", "ERA", "PD", "PGH", "PL", "TRY", "NPY", "RUF" ] +fixable = [ "ALL" ] +unfixable = [ ] ignore = [ # Formatter conflicts - "COM812", # Trailing comma missing (conflicts with formatter) - "ISC001", # Implicitly concatenated string literals (conflicts with formatter) + "COM812", # Trailing comma missing (conflicts with formatter) + "ISC001", # Implicitly concatenated string literals (conflicts with formatter) # CLI script patterns "PLR2004", # Magic value used in comparison - OK for CLI constants and thresholds - "FBT001", # Boolean-typed positional argument - OK for CLI flags - "FBT002", # Boolean default positional argument - OK for CLI flags - "BLE001", # Do not catch blind exception - OK for CLI robustness - "T201", # print found - OK for CLI output - "TRY300", # Consider moving statement to else block - OK for CLI control flow - "TRY301", # Abstract raise to inner function - OK for straightforward CLI error handling - "ARG001", # Unused function argument - common in callbacks - "ERA001", # Found commented-out code - OK for explanatory comments - "EXE001", # Shebang present but file not executable - handled by packaging - "PTH123", # open() should be replaced by Path.open() - OK for some call sites - "EM102", # Exception must not use f-string - OK for CLI error messages - "TRY003", # Avoid specifying long messages outside exception class - OK for CLI reporting + "FBT001", # Boolean-typed positional argument - OK for CLI flags + "FBT002", # Boolean default positional argument - OK for CLI flags + "BLE001", # Do not catch blind exception - OK for CLI robustness + "T201", # print found - OK for CLI output + "TRY300", # Consider moving statement to else block - OK for CLI control flow + "TRY301", # Abstract raise to inner function - OK for straightforward CLI error handling + "ARG001", # Unused function argument - common in callbacks + "ERA001", # Found commented-out code - OK for explanatory comments + "EXE001", # Shebang present but file not executable - handled by packaging + "PTH123", # open() should be replaced by Path.open() - OK for some call sites + "EM102", # Exception must not use f-string - OK for CLI error messages + "TRY003", # Avoid specifying long messages outside exception class - OK for CLI reporting ] [tool.ruff.lint.per-file-ignores] "scripts/tests/**/*.py" = [ - "S101", # asserts are fine in tests - "SLF001", # tests may call internal helpers + "S101", # asserts are fine in tests + "SLF001", # tests may call internal helpers ] [tool.ruff.format] @@ -86,6 +86,14 @@ indent-style = "space" skip-magic-trailing-comma = false line-ending = "auto" +[tool.pytest.ini_options] +minversion = "8.0" +addopts = [ "-ra", "--strict-markers", "--strict-config", "--color=yes" ] +testpaths = [ "scripts/tests" ] +python_files = [ "test_*.py", "*_test.py" ] +python_classes = [ "Test*" ] +python_functions = [ "test_*" ] + [tool.mypy] python_version = "3.11" mypy_path = "scripts" diff --git a/rust-toolchain.toml b/rust-toolchain.toml index c0748b9..102ef90 100644 --- a/rust-toolchain.toml +++ b/rust-toolchain.toml @@ -4,21 +4,21 @@ channel = "1.92.0" # Essential components for development components = [ - "cargo", # Package manager - "clippy", # Linting (you use strict pedantic mode) - "rustfmt", # Code formatting (you use cargo fmt --all) - "rust-docs", # Local documentation - "rust-std", # Standard library - "rust-src", # Source code (helpful for IDEs) - "rust-analyzer", # Language server for IDE support + "cargo", # Package manager + "clippy", # Linting (you use strict pedantic mode) + "rustfmt", # Code formatting (you use cargo fmt --all) + "rust-docs", # Local documentation + "rust-std", # Standard library + "rust-src", # Source code (helpful for IDEs) + "rust-analyzer", # Language server for IDE support ] # Target the platforms you support (adjust as needed) targets = [ - "x86_64-apple-darwin", # macOS Intel - "aarch64-apple-darwin", # macOS Apple Silicon + "x86_64-apple-darwin", # macOS Intel + "aarch64-apple-darwin", # macOS Apple Silicon "x86_64-unknown-linux-gnu", # Linux - "x86_64-pc-windows-msvc", # Windows + "x86_64-pc-windows-msvc", # Windows ] # Set this toolchain as the profile default diff --git a/src/ldlt.rs b/src/ldlt.rs new file mode 100644 index 0000000..4dce828 --- /dev/null +++ b/src/ldlt.rs @@ -0,0 +1,336 @@ +//! LDLT factorization and solves. +//! +//! This module provides a stack-allocated LDLT factorization (`A = L D Lแต€`) intended for +//! symmetric positive definite (SPD) and positive semi-definite (PSD) matrices (e.g. Gram +//! matrices) without pivoting. + +use crate::LaError; +use crate::matrix::Matrix; +use crate::vector::Vector; + +/// LDLT factorization (`A = L D Lแต€`) for symmetric positive (semi)definite matrices. +/// +/// This factorization is **not** a general-purpose symmetric-indefinite LDLT (no pivoting). +/// It assumes the input matrix is symmetric and (numerically) SPD/PSD. +/// +/// # Storage +/// The factors are stored in a single [`Matrix`]: +/// - `D` is stored on the diagonal. +/// - The strict lower triangle stores the multipliers of `L`. +/// - The diagonal of `L` is implicit ones. +#[must_use] +#[derive(Clone, Copy, Debug, PartialEq)] +pub struct Ldlt { + factors: Matrix, + tol: f64, +} + +impl Ldlt { + #[inline] + pub(crate) fn factor(a: Matrix, tol: f64) -> Result { + debug_assert!(tol >= 0.0, "tol must be non-negative"); + + #[cfg(debug_assertions)] + debug_assert_symmetric(&a); + + let mut f = a; + + // LDLT via symmetric rank-1 updates, using only the lower triangle. + for j in 0..D { + let d = f.rows[j][j]; + if !d.is_finite() { + return Err(LaError::NonFinite { pivot_col: j }); + } + if d <= tol { + return Err(LaError::Singular { pivot_col: j }); + } + + // Compute L multipliers below the diagonal in column j. + for i in (j + 1)..D { + let l = f.rows[i][j] / d; + if !l.is_finite() { + return Err(LaError::NonFinite { pivot_col: j }); + } + f.rows[i][j] = l; + } + + // Update the trailing submatrix (lower triangle): A := A - (L_col * d) * L_col^T. + for i in (j + 1)..D { + let l_i = f.rows[i][j]; + let l_i_d = l_i * d; + + for k in (j + 1)..=i { + let l_k = f.rows[k][j]; + let new_val = (-l_i_d).mul_add(l_k, f.rows[i][k]); + if !new_val.is_finite() { + return Err(LaError::NonFinite { pivot_col: j }); + } + f.rows[i][k] = new_val; + } + } + } + + Ok(Self { factors: f, tol }) + } + + /// Determinant of the original matrix. + /// + /// For SPD/PSD matrices, this is the product of the diagonal terms of `D`. + /// + /// # Examples + /// ``` + /// use la_stack::prelude::*; + /// + /// // Symmetric SPD matrix. + /// let a = Matrix::<2>::from_rows([[4.0, 2.0], [2.0, 3.0]]); + /// let ldlt = a.ldlt(DEFAULT_SINGULAR_TOL).unwrap(); + /// + /// assert!((ldlt.det() - 8.0).abs() <= 1e-12); + /// ``` + #[inline] + #[must_use] + pub fn det(&self) -> f64 { + let mut det = 1.0; + for i in 0..D { + det *= self.factors.rows[i][i]; + } + det + } + + /// Solve `A x = b` using this LDLT factorization. + /// + /// # Examples + /// ``` + /// use la_stack::prelude::*; + /// + /// # fn main() -> Result<(), LaError> { + /// let a = Matrix::<2>::from_rows([[4.0, 2.0], [2.0, 3.0]]); + /// let ldlt = a.ldlt(DEFAULT_SINGULAR_TOL)?; + /// + /// let b = Vector::<2>::new([1.0, 2.0]); + /// let x = ldlt.solve_vec(b)?.into_array(); + /// + /// assert!((x[0] - (-0.125)).abs() <= 1e-12); + /// assert!((x[1] - 0.75).abs() <= 1e-12); + /// # Ok(()) + /// # } + /// ``` + /// + /// # Errors + /// Returns [`LaError::Singular`] if a diagonal entry `d = D[i,i]` satisfies `d <= tol` + /// (non-positive or too small), where `tol` is the tolerance that was used during factorization. + /// Returns [`LaError::NonFinite`] if NaN/โˆž is detected. + #[inline] + pub fn solve_vec(&self, b: Vector) -> Result, LaError> { + let mut x = b.data; + + // Forward substitution: L y = b (L has unit diagonal). + for i in 0..D { + let mut sum = x[i]; + let row = self.factors.rows[i]; + for (j, x_j) in x.iter().enumerate().take(i) { + sum = (-row[j]).mul_add(*x_j, sum); + } + if !sum.is_finite() { + return Err(LaError::NonFinite { pivot_col: i }); + } + x[i] = sum; + } + + // Diagonal solve: D z = y. + for (i, x_i) in x.iter_mut().enumerate().take(D) { + let diag = self.factors.rows[i][i]; + if !diag.is_finite() { + return Err(LaError::NonFinite { pivot_col: i }); + } + if diag <= self.tol { + return Err(LaError::Singular { pivot_col: i }); + } + + let v = *x_i / diag; + if !v.is_finite() { + return Err(LaError::NonFinite { pivot_col: i }); + } + *x_i = v; + } + + // Back substitution: Lแต€ x = z. + for ii in 0..D { + let i = D - 1 - ii; + let mut sum = x[i]; + for (j, x_j) in x.iter().enumerate().skip(i + 1) { + sum = (-self.factors.rows[j][i]).mul_add(*x_j, sum); + } + if !sum.is_finite() { + return Err(LaError::NonFinite { pivot_col: i }); + } + x[i] = sum; + } + + Ok(Vector::new(x)) + } +} + +#[cfg(debug_assertions)] +fn debug_assert_symmetric(a: &Matrix) { + let scale = a.inf_norm().max(1.0); + let eps = 1e-12 * scale; + + for r in 0..D { + for c in (r + 1)..D { + let diff = (a.rows[r][c] - a.rows[c][r]).abs(); + debug_assert!( + diff <= eps, + "matrix must be symmetric (diff={diff}, eps={eps}) at ({r}, {c})" + ); + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + use crate::DEFAULT_SINGULAR_TOL; + + use core::hint::black_box; + + use approx::assert_abs_diff_eq; + use pastey::paste; + + macro_rules! gen_public_api_ldlt_identity_tests { + ($d:literal) => { + paste! { + #[test] + fn []() { + let a = Matrix::<$d>::identity(); + let ldlt = a.ldlt(DEFAULT_SINGULAR_TOL).unwrap(); + + assert_abs_diff_eq!(ldlt.det(), 1.0, epsilon = 1e-12); + + let b_arr = { + let mut arr = [0.0f64; $d]; + let values = [1.0f64, 2.0, 3.0, 4.0, 5.0]; + for (dst, src) in arr.iter_mut().zip(values.iter()) { + *dst = *src; + } + arr + }; + let b = Vector::<$d>::new(black_box(b_arr)); + let x = ldlt.solve_vec(b).unwrap().into_array(); + + for i in 0..$d { + assert_abs_diff_eq!(x[i], b_arr[i], epsilon = 1e-12); + } + } + } + }; + } + + gen_public_api_ldlt_identity_tests!(2); + gen_public_api_ldlt_identity_tests!(3); + gen_public_api_ldlt_identity_tests!(4); + gen_public_api_ldlt_identity_tests!(5); + + macro_rules! gen_public_api_ldlt_diagonal_tests { + ($d:literal) => { + paste! { + #[test] + fn []() { + let diag = { + let mut arr = [0.0f64; $d]; + let values = [1.0f64, 2.0, 3.0, 4.0, 5.0]; + for (dst, src) in arr.iter_mut().zip(values.iter()) { + *dst = *src; + } + arr + }; + + let mut rows = [[0.0f64; $d]; $d]; + for i in 0..$d { + rows[i][i] = diag[i]; + } + + let a = Matrix::<$d>::from_rows(black_box(rows)); + let ldlt = a.ldlt(DEFAULT_SINGULAR_TOL).unwrap(); + + let expected_det = { + let mut acc = 1.0; + for i in 0..$d { + acc *= diag[i]; + } + acc + }; + assert_abs_diff_eq!(ldlt.det(), expected_det, epsilon = 1e-12); + + let b_arr = { + let mut arr = [0.0f64; $d]; + let values = [5.0f64, 4.0, 3.0, 2.0, 1.0]; + for (dst, src) in arr.iter_mut().zip(values.iter()) { + *dst = *src; + } + arr + }; + + let b = Vector::<$d>::new(black_box(b_arr)); + let x = ldlt.solve_vec(b).unwrap().into_array(); + + for i in 0..$d { + assert_abs_diff_eq!(x[i], b_arr[i] / diag[i], epsilon = 1e-12); + } + } + } + }; + } + + gen_public_api_ldlt_diagonal_tests!(2); + gen_public_api_ldlt_diagonal_tests!(3); + gen_public_api_ldlt_diagonal_tests!(4); + gen_public_api_ldlt_diagonal_tests!(5); + + #[test] + fn solve_2x2_known_spd() { + let a = Matrix::<2>::from_rows(black_box([[4.0, 2.0], [2.0, 3.0]])); + let ldlt = (black_box(Ldlt::<2>::factor))(a, DEFAULT_SINGULAR_TOL).unwrap(); + + let b = Vector::<2>::new(black_box([1.0, 2.0])); + let x = ldlt.solve_vec(b).unwrap().into_array(); + + assert_abs_diff_eq!(x[0], -0.125, epsilon = 1e-12); + assert_abs_diff_eq!(x[1], 0.75, epsilon = 1e-12); + assert_abs_diff_eq!(ldlt.det(), 8.0, epsilon = 1e-12); + } + + #[test] + fn solve_3x3_spd_tridiagonal_smoke() { + let a = Matrix::<3>::from_rows(black_box([ + [2.0, -1.0, 0.0], + [-1.0, 2.0, -1.0], + [0.0, -1.0, 2.0], + ])); + let ldlt = a.ldlt(DEFAULT_SINGULAR_TOL).unwrap(); + + // Choose x = 1 so b = A x is simple: [1, 0, 1]. + let b = Vector::<3>::new(black_box([1.0, 0.0, 1.0])); + let x = ldlt.solve_vec(b).unwrap().into_array(); + + for &x_i in &x { + assert_abs_diff_eq!(x_i, 1.0, epsilon = 1e-9); + } + } + + #[test] + fn singular_detected_for_degenerate_psd() { + // Rank-1 Gram-like matrix. + let a = Matrix::<2>::from_rows(black_box([[1.0, 1.0], [1.0, 1.0]])); + let err = a.ldlt(DEFAULT_SINGULAR_TOL).unwrap_err(); + assert_eq!(err, LaError::Singular { pivot_col: 1 }); + } + + #[test] + fn nonfinite_detected() { + let a = Matrix::<2>::from_rows([[f64::NAN, 0.0], [0.0, 1.0]]); + let err = a.ldlt(DEFAULT_SINGULAR_TOL).unwrap_err(); + assert_eq!(err, LaError::NonFinite { pivot_col: 0 }); + } +} diff --git a/src/lib.rs b/src/lib.rs index bd4b323..3210188 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -29,25 +29,52 @@ mod readme_doctests { /// } /// ``` fn solve_5x5_example() {} + + /// ```rust + /// use la_stack::prelude::*; + /// + /// // This matrix is symmetric positive-definite (A = L*L^T) so LDLT works without pivoting. + /// let a = Matrix::<5>::from_rows([ + /// [1.0, 1.0, 0.0, 0.0, 0.0], + /// [1.0, 2.0, 1.0, 0.0, 0.0], + /// [0.0, 1.0, 2.0, 1.0, 0.0], + /// [0.0, 0.0, 1.0, 2.0, 1.0], + /// [0.0, 0.0, 0.0, 1.0, 2.0], + /// ]); + /// + /// let det = a.ldlt(DEFAULT_SINGULAR_TOL).unwrap().det(); + /// assert!((det - 1.0).abs() <= 1e-12); + /// ``` + fn det_5x5_ldlt_example() {} } +mod ldlt; mod lu; mod matrix; mod vector; use core::fmt; -/// Default absolute pivot tolerance used for singularity detection. +/// Default absolute threshold used for singularity/degeneracy detection. /// /// This is intentionally conservative for geometric predicates and small systems. -pub const DEFAULT_PIVOT_TOL: f64 = 1e-12; +/// +/// Conceptually, this is an absolute bound for deciding when a scalar should be treated +/// as "numerically zero" (e.g. LU pivots, LDLT diagonal entries). +pub const DEFAULT_SINGULAR_TOL: f64 = 1e-12; + +/// Default absolute pivot magnitude threshold used for LU pivot selection / singularity detection. +/// +/// This name is kept for backwards compatibility; prefer [`DEFAULT_SINGULAR_TOL`] when the +/// tolerance is not specifically about pivot selection. +pub const DEFAULT_PIVOT_TOL: f64 = DEFAULT_SINGULAR_TOL; /// Linear algebra errors. #[derive(Clone, Copy, Debug, PartialEq, Eq)] pub enum LaError { /// The matrix is (numerically) singular. Singular { - /// The column where a suitable pivot could not be found. + /// The factorization column/step where a suitable pivot/diagonal could not be found. pivot_col: usize, }, /// A non-finite value (NaN/โˆž) was encountered. @@ -75,6 +102,7 @@ impl fmt::Display for LaError { impl std::error::Error for LaError {} +pub use ldlt::Ldlt; pub use lu::Lu; pub use matrix::Matrix; pub use vector::Vector; @@ -82,9 +110,9 @@ pub use vector::Vector; /// Common imports for ergonomic usage. /// /// This prelude re-exports the primary types and constants: [`Matrix`], [`Vector`], [`Lu`], -/// [`LaError`], and [`DEFAULT_PIVOT_TOL`]. +/// [`Ldlt`], [`LaError`], [`DEFAULT_PIVOT_TOL`], and [`DEFAULT_SINGULAR_TOL`]. pub mod prelude { - pub use crate::{DEFAULT_PIVOT_TOL, LaError, Lu, Matrix, Vector}; + pub use crate::{DEFAULT_PIVOT_TOL, DEFAULT_SINGULAR_TOL, LaError, Ldlt, Lu, Matrix, Vector}; } #[cfg(test)] @@ -94,8 +122,9 @@ mod tests { use approx::assert_abs_diff_eq; #[test] - fn default_pivot_tol_is_expected() { - assert_abs_diff_eq!(DEFAULT_PIVOT_TOL, 1e-12, epsilon = 0.0); + fn default_singular_tol_is_expected() { + assert_abs_diff_eq!(DEFAULT_SINGULAR_TOL, 1e-12, epsilon = 0.0); + assert_abs_diff_eq!(DEFAULT_PIVOT_TOL, DEFAULT_SINGULAR_TOL, epsilon = 0.0); } #[test] @@ -128,5 +157,6 @@ mod tests { let m = Matrix::<2>::identity(); let v = Vector::<2>::new([1.0, 2.0]); let _ = m.lu(DEFAULT_PIVOT_TOL).unwrap().solve_vec(v).unwrap(); + let _ = m.ldlt(DEFAULT_SINGULAR_TOL).unwrap().solve_vec(v).unwrap(); } } diff --git a/src/lu.rs b/src/lu.rs index 92ab801..4de031b 100644 --- a/src/lu.rs +++ b/src/lu.rs @@ -102,7 +102,8 @@ impl Lu { /// ``` /// /// # Errors - /// Returns [`LaError::Singular`] if a diagonal of `U` is (numerically) zero. + /// Returns [`LaError::Singular`] if a diagonal entry of `U` satisfies `|u_ii| <= tol`, where + /// `tol` is the tolerance that was used during factorization. /// Returns [`LaError::NonFinite`] if NaN/โˆž is detected. #[inline] pub fn solve_vec(&self, b: Vector) -> Result, LaError> { diff --git a/src/matrix.rs b/src/matrix.rs index 6d7a5bb..06b57af 100644 --- a/src/matrix.rs +++ b/src/matrix.rs @@ -1,6 +1,7 @@ //! Fixed-size, stack-allocated square matrices. use crate::LaError; +use crate::ldlt::Ldlt; use crate::lu::Lu; /// Fixed-size square matrix `Dร—D`, stored inline. @@ -152,13 +153,49 @@ impl Matrix { /// ``` /// /// # Errors - /// Returns [`LaError::Singular`] if no suitable pivot (|pivot| > `tol`) exists for a column. + /// Returns [`LaError::Singular`] if, for some column `k`, the largest-magnitude candidate pivot + /// in that column satisfies `|pivot| <= tol` (so no numerically usable pivot exists). /// Returns [`LaError::NonFinite`] if NaN/โˆž is detected during factorization. #[inline] pub fn lu(self, tol: f64) -> Result, LaError> { Lu::factor(self, tol) } + /// Compute an LDLT factorization (`A = L D Lแต€`) without pivoting. + /// + /// This is intended for symmetric positive definite (SPD) and positive semi-definite (PSD) + /// matrices such as Gram matrices. + /// + /// # Examples + /// ``` + /// use la_stack::prelude::*; + /// + /// # fn main() -> Result<(), LaError> { + /// let a = Matrix::<2>::from_rows([[4.0, 2.0], [2.0, 3.0]]); + /// let ldlt = a.ldlt(DEFAULT_SINGULAR_TOL)?; + /// + /// // det(A) = 8 + /// assert!((ldlt.det() - 8.0).abs() <= 1e-12); + /// + /// // Solve A x = b + /// let b = Vector::<2>::new([1.0, 2.0]); + /// let x = ldlt.solve_vec(b)?.into_array(); + /// assert!((x[0] - (-0.125)).abs() <= 1e-12); + /// assert!((x[1] - 0.75).abs() <= 1e-12); + /// # Ok(()) + /// # } + /// ``` + /// + /// # Errors + /// Returns [`LaError::Singular`] if, for some step `k`, the required diagonal entry `d = D[k,k]` + /// is `<= tol` (non-positive or too small). This treats PSD degeneracy (and indefinite inputs) + /// as singular/degenerate. + /// Returns [`LaError::NonFinite`] if NaN/โˆž is detected during factorization. + #[inline] + pub fn ldlt(self, tol: f64) -> Result, LaError> { + Ldlt::factor(self, tol) + } + /// Determinant computed via LU decomposition. /// /// # Examples diff --git a/tests/proptest_factorizations.rs b/tests/proptest_factorizations.rs new file mode 100644 index 0000000..e26d292 --- /dev/null +++ b/tests/proptest_factorizations.rs @@ -0,0 +1,275 @@ +//! Property-based tests for LU/LDLT factorization APIs. +//! +//! These tests construct matrices from known factors so we have a reliable oracle for +//! determinant and solve behavior. + +use approx::assert_abs_diff_eq; +use pastey::paste; +use proptest::prelude::*; + +use la_stack::prelude::*; + +fn small_f64() -> impl Strategy { + (-1000i16..=1000i16).prop_map(|x| f64::from(x) / 10.0) +} + +fn small_factor_entry() -> impl Strategy { + // Keep entries small so constructed matrices are reasonably conditioned. + (-50i16..=50i16).prop_map(|x| f64::from(x) / 100.0) +} + +fn positive_diag_entry() -> impl Strategy { + // Strictly positive diagonal, comfortably above DEFAULT_SINGULAR_TOL. + (1i16..=20i16).prop_map(|x| f64::from(x) / 10.0) +} + +fn nonzero_diag_entry() -> impl Strategy { + // Strictly non-zero diagonal with a margin from 0. + prop_oneof![(-20i16..=-1i16), (1i16..=20i16)].prop_map(|x| f64::from(x) / 10.0) +} + +macro_rules! gen_factorization_proptests { + ($d:literal) => { + paste! { + proptest! { + #![proptest_config(ProptestConfig::with_cases(64))] + + #[test] + fn []( + l_raw in proptest::array::[]( + proptest::array::[](small_factor_entry()), + ), + d_diag in proptest::array::[](positive_diag_entry()), + x_true in proptest::array::[](small_f64()), + ) { + // Construct A = L * diag(D) * L^T, where L is unit-lower-triangular. + let mut l = [[0.0f64; $d]; $d]; + for i in 0..$d { + for j in 0..$d { + l[i][j] = if i == j { + 1.0 + } else if i > j { + l_raw[i][j] + } else { + 0.0 + }; + } + } + + let mut a_rows = [[0.0f64; $d]; $d]; + for i in 0..$d { + for j in 0..=i { + let mut sum = 0.0; + // L[j][k] is zero for k > j. + for k in 0..=j { + sum = (l[i][k] * d_diag[k]).mul_add(l[j][k], sum); + } + a_rows[i][j] = sum; + a_rows[j][i] = sum; + } + } + + let expected_det = { + let mut acc = 1.0; + for i in 0..$d { + acc *= d_diag[i]; + } + acc + }; + + let mut b_arr = [0.0f64; $d]; + for i in 0..$d { + let mut sum = 0.0; + for j in 0..$d { + sum = a_rows[i][j].mul_add(x_true[j], sum); + } + b_arr[i] = sum; + } + + let a = Matrix::<$d>::from_rows(a_rows); + let ldlt = a.ldlt(DEFAULT_SINGULAR_TOL).unwrap(); + + assert_abs_diff_eq!(ldlt.det(), expected_det, epsilon = 1e-8); + + let b = Vector::<$d>::new(b_arr); + let x = ldlt.solve_vec(b).unwrap().into_array(); + for i in 0..$d { + assert_abs_diff_eq!(x[i], x_true[i], epsilon = 1e-8); + } + } + + #[test] + fn []( + l_raw in proptest::array::[]( + proptest::array::[](small_factor_entry()), + ), + u_raw in proptest::array::[]( + proptest::array::[](small_factor_entry()), + ), + u_diag in proptest::array::[](nonzero_diag_entry()), + x_true in proptest::array::[](small_f64()), + ) { + // Construct A = L * U, where L is unit-lower-triangular and U is upper-triangular. + let mut l = [[0.0f64; $d]; $d]; + for i in 0..$d { + for j in 0..$d { + l[i][j] = if i == j { + 1.0 + } else if i > j { + l_raw[i][j] + } else { + 0.0 + }; + } + } + + let mut u = [[0.0f64; $d]; $d]; + for i in 0..$d { + for j in 0..$d { + u[i][j] = if i == j { + u_diag[i] + } else if i < j { + u_raw[i][j] + } else { + 0.0 + }; + } + } + + let mut a_rows = [[0.0f64; $d]; $d]; + for i in 0..$d { + for j in 0..$d { + let mut sum = 0.0; + // L[i][k] is zero for k > i; U[k][j] is zero for k > j. + let k_max = if i < j { i } else { j }; + for k in 0..=k_max { + sum = l[i][k].mul_add(u[k][j], sum); + } + a_rows[i][j] = sum; + } + } + + let expected_det = { + let mut acc = 1.0; + for i in 0..$d { + acc *= u_diag[i]; + } + acc + }; + + let mut b_arr = [0.0f64; $d]; + for i in 0..$d { + let mut sum = 0.0; + for j in 0..$d { + sum = a_rows[i][j].mul_add(x_true[j], sum); + } + b_arr[i] = sum; + } + + let a = Matrix::<$d>::from_rows(a_rows); + let lu = a.lu(DEFAULT_PIVOT_TOL).unwrap(); + + assert_abs_diff_eq!(lu.det(), expected_det, epsilon = 1e-8); + + let b = Vector::<$d>::new(b_arr); + let x = lu.solve_vec(b).unwrap().into_array(); + for i in 0..$d { + assert_abs_diff_eq!(x[i], x_true[i], epsilon = 1e-8); + } + } + + #[test] + fn []( + l_raw in proptest::array::[]( + proptest::array::[](small_factor_entry()), + ), + u_raw in proptest::array::[]( + proptest::array::[](small_factor_entry()), + ), + u_diag in proptest::array::[](nonzero_diag_entry()), + x_true in proptest::array::[](small_f64()), + ) { + // Construct A = P^{-1} * L * U, where P swaps the first two rows. + // This ensures det(A) has an extra sign flip vs det(LU). + prop_assume!($d >= 2); + + let mut l = [[0.0f64; $d]; $d]; + for i in 0..$d { + for j in 0..$d { + l[i][j] = if i == j { + 1.0 + } else if i > j { + l_raw[i][j] + } else { + 0.0 + }; + } + } + + let mut u = [[0.0f64; $d]; $d]; + for i in 0..$d { + for j in 0..$d { + u[i][j] = if i == j { + u_diag[i] + } else if i < j { + u_raw[i][j] + } else { + 0.0 + }; + } + } + + let mut lu_rows = [[0.0f64; $d]; $d]; + for i in 0..$d { + for j in 0..$d { + let mut sum = 0.0; + let k_max = if i < j { i } else { j }; + for k in 0..=k_max { + sum = l[i][k].mul_add(u[k][j], sum); + } + lu_rows[i][j] = sum; + } + } + + // Apply P^{-1}: swap rows 0 and 1. + let mut a_rows = lu_rows; + a_rows.swap(0, 1); + + let expected_det = { + let mut acc = 1.0; + for i in 0..$d { + acc *= u_diag[i]; + } + -acc + }; + + let mut b_arr = [0.0f64; $d]; + for i in 0..$d { + let mut sum = 0.0; + for j in 0..$d { + sum = a_rows[i][j].mul_add(x_true[j], sum); + } + b_arr[i] = sum; + } + + let a = Matrix::<$d>::from_rows(a_rows); + let lu = a.lu(DEFAULT_PIVOT_TOL).unwrap(); + + assert_abs_diff_eq!(lu.det(), expected_det, epsilon = 1e-8); + + let b = Vector::<$d>::new(b_arr); + let x = lu.solve_vec(b).unwrap().into_array(); + for i in 0..$d { + assert_abs_diff_eq!(x[i], x_true[i], epsilon = 1e-8); + } + } + } + } + }; +} + +// Mirror delaunay-style multi-dimension tests. +gen_factorization_proptests!(2); +gen_factorization_proptests!(3); +gen_factorization_proptests!(4); +gen_factorization_proptests!(5); diff --git a/tests/proptest_matrix.rs b/tests/proptest_matrix.rs index c6795bf..338e213 100644 --- a/tests/proptest_matrix.rs +++ b/tests/proptest_matrix.rs @@ -14,6 +14,16 @@ fn small_nonzero_f64() -> impl Strategy { prop_oneof![(-1000i16..=-1i16), (1i16..=1000i16)].prop_map(|x| f64::from(x) / 10.0) } +fn small_ldlt_l_entry() -> impl Strategy { + // Keep entries small so SPD construction stays well-conditioned. + (-50i16..=50i16).prop_map(|x| f64::from(x) / 100.0) +} + +fn positive_ldlt_diag() -> impl Strategy { + // Positive diagonal, comfortably above DEFAULT_PIVOT_TOL. + (1i16..=20i16).prop_map(|x| f64::from(x) / 10.0) +} + macro_rules! gen_public_api_matrix_proptests { ($d:literal) => { paste! { @@ -98,6 +108,65 @@ macro_rules! gen_public_api_matrix_proptests { assert_abs_diff_eq!(x[i], expected_x, epsilon = 1e-12); } } + + #[test] + fn []( + l_raw in proptest::array::[]( + proptest::array::[](small_ldlt_l_entry()), + ), + d_diag in proptest::array::[](positive_ldlt_diag()), + x_true in proptest::array::[](small_f64()), + ) { + // Construct an SPD matrix A = L * diag(D) * L^T, where L is unit-lower-triangular + // and D has strictly positive entries. + let mut l = [[0.0f64; $d]; $d]; + for i in 0..$d { + for j in 0..$d { + l[i][j] = if i == j { + 1.0 + } else if i > j { + l_raw[i][j] + } else { + 0.0 + }; + } + } + + let mut a_rows = [[0.0f64; $d]; $d]; + for i in 0..$d { + for j in 0..=i { + let mut sum = 0.0; + for k in 0..=j { + sum = (l[i][k] * d_diag[k]).mul_add(l[j][k], sum); + } + a_rows[i][j] = sum; + a_rows[j][i] = sum; + } + } + + let mut b_arr = [0.0f64; $d]; + for i in 0..$d { + let mut sum = 0.0; + for j in 0..$d { + sum = a_rows[i][j].mul_add(x_true[j], sum); + } + b_arr[i] = sum; + } + + let a = Matrix::<$d>::from_rows(a_rows); + let ldlt = a.ldlt(DEFAULT_SINGULAR_TOL).unwrap(); + + let det_ldlt = ldlt.det(); + let det_lu = a.det(DEFAULT_PIVOT_TOL).unwrap(); + assert_abs_diff_eq!(det_ldlt, det_lu, epsilon = 1e-8); + + let b = Vector::<$d>::new(b_arr); + let x = ldlt.solve_vec(b).unwrap().into_array(); + + for i in 0..$d { + assert_abs_diff_eq!(x[i], x_true[i], epsilon = 1e-8); + } + } } } };